mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 16:18:13 +00:00
Compare commits
27 Commits
devin/1769
...
gl/feat/no
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
515ce8f55f | ||
|
|
1d40f5d83c | ||
|
|
3afac2a696 | ||
|
|
5fab437b7f | ||
|
|
30684f387e | ||
|
|
f2b4efe7fa | ||
|
|
4f175fdd6f | ||
|
|
d72b79f932 | ||
|
|
e8638d318d | ||
|
|
d2c880c6b3 | ||
|
|
087f6d25a9 | ||
|
|
c57e325482 | ||
|
|
fdb7047780 | ||
|
|
adb485f7f7 | ||
|
|
ee64bd426e | ||
|
|
37b80ee937 | ||
|
|
bf9ccd418a | ||
|
|
bd95356ec5 | ||
|
|
441591d592 | ||
|
|
132b6b224a | ||
|
|
4e2916d71a | ||
|
|
0c4a0e1fda | ||
|
|
9c4126e0d8 | ||
|
|
5156fc4792 | ||
|
|
c600b26ca6 | ||
|
|
162a106002 | ||
|
|
be33c8e3e5 |
@@ -38,6 +38,7 @@ dependencies = [
|
|||||||
"pydantic-settings~=2.10.1",
|
"pydantic-settings~=2.10.1",
|
||||||
"mcp~=1.16.0",
|
"mcp~=1.16.0",
|
||||||
"uv~=0.9.13",
|
"uv~=0.9.13",
|
||||||
|
"aiosqlite~=0.21.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
import json
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.a2a.config import A2AConfig
|
from crewai.a2a.config import A2AConfig
|
||||||
|
from crewai.agent.utils import (
|
||||||
|
ahandle_knowledge_retrieval,
|
||||||
|
apply_training_data,
|
||||||
|
build_task_prompt_with_schema,
|
||||||
|
format_task_with_context,
|
||||||
|
get_knowledge_config,
|
||||||
|
handle_knowledge_retrieval,
|
||||||
|
handle_reasoning,
|
||||||
|
prepare_tools,
|
||||||
|
process_tool_results,
|
||||||
|
save_last_messages,
|
||||||
|
validate_max_execution_time,
|
||||||
|
)
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
|
|||||||
KnowledgeQueryCompletedEvent,
|
KnowledgeQueryCompletedEvent,
|
||||||
KnowledgeQueryFailedEvent,
|
KnowledgeQueryFailedEvent,
|
||||||
KnowledgeQueryStartedEvent,
|
KnowledgeQueryStartedEvent,
|
||||||
KnowledgeRetrievalCompletedEvent,
|
|
||||||
KnowledgeRetrievalStartedEvent,
|
|
||||||
KnowledgeSearchQueryFailedEvent,
|
|
||||||
)
|
)
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryRetrievalCompletedEvent,
|
MemoryRetrievalCompletedEvent,
|
||||||
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
|
|||||||
)
|
)
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
|
||||||
from crewai.lite_agent import LiteAgent
|
from crewai.lite_agent import LiteAgent
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.mcp import (
|
from crewai.mcp import (
|
||||||
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
)
|
)
|
||||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||||
from crewai.utilities.converter import Converter, generate_model_description
|
from crewai.utilities.converter import Converter
|
||||||
from crewai.utilities.guardrail_types import GuardrailType
|
from crewai.utilities.guardrail_types import GuardrailType
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
from crewai.utilities.prompts import Prompts
|
from crewai.utilities.prompts import Prompts
|
||||||
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
|
|||||||
ValueError: If the max execution time is not a positive integer.
|
ValueError: If the max execution time is not a positive integer.
|
||||||
RuntimeError: If the agent execution fails for other reasons.
|
RuntimeError: If the agent execution fails for other reasons.
|
||||||
"""
|
"""
|
||||||
if self.reasoning:
|
handle_reasoning(self, task)
|
||||||
try:
|
|
||||||
from crewai.utilities.reasoning_handler import (
|
|
||||||
AgentReasoning,
|
|
||||||
AgentReasoningOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
reasoning_handler = AgentReasoning(task=task, agent=self)
|
|
||||||
reasoning_output: AgentReasoningOutput = (
|
|
||||||
reasoning_handler.handle_agent_reasoning()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the reasoning plan to the task description
|
|
||||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
|
||||||
self._inject_date_to_task(task)
|
self._inject_date_to_task(task)
|
||||||
|
|
||||||
if self.tools_handler:
|
if self.tools_handler:
|
||||||
self.tools_handler.last_used_tool = None
|
self.tools_handler.last_used_tool = None
|
||||||
|
|
||||||
task_prompt = task.prompt()
|
task_prompt = task.prompt()
|
||||||
|
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||||
# If the task requires output in JSON or Pydantic format,
|
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||||
# append specific instructions to the task prompt to ensure
|
|
||||||
# that the final answer does not include any code block markers
|
|
||||||
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
|
|
||||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
|
||||||
# Generate the schema based on the output format
|
|
||||||
if task.output_json:
|
|
||||||
schema_dict = generate_model_description(task.output_json)
|
|
||||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
|
||||||
task_prompt += "\n" + self.i18n.slice(
|
|
||||||
"formatted_task_instructions"
|
|
||||||
).format(output_format=schema)
|
|
||||||
|
|
||||||
elif task.output_pydantic:
|
|
||||||
schema_dict = generate_model_description(task.output_pydantic)
|
|
||||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
|
||||||
task_prompt += "\n" + self.i18n.slice(
|
|
||||||
"formatted_task_instructions"
|
|
||||||
).format(output_format=schema)
|
|
||||||
|
|
||||||
if context:
|
|
||||||
task_prompt = self.i18n.slice("task_with_context").format(
|
|
||||||
task=task_prompt, context=context
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._is_any_available_memory():
|
if self._is_any_available_memory():
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
|
|||||||
from_task=task,
|
from_task=task,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
knowledge_config = (
|
|
||||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
knowledge_config = get_knowledge_config(self)
|
||||||
|
task_prompt = handle_knowledge_retrieval(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
task_prompt,
|
||||||
|
knowledge_config,
|
||||||
|
self.knowledge.query if self.knowledge else lambda *a, **k: None,
|
||||||
|
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
prepare_tools(self, tools, task)
|
||||||
crewai_event_bus.emit(
|
task_prompt = apply_training_data(self, task_prompt)
|
||||||
self,
|
|
||||||
event=KnowledgeRetrievalStartedEvent(
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
|
||||||
task_prompt, task
|
|
||||||
)
|
|
||||||
if self.knowledge_search_query:
|
|
||||||
# Quering agent specific knowledge
|
|
||||||
if self.knowledge:
|
|
||||||
agent_knowledge_snippets = self.knowledge.query(
|
|
||||||
[self.knowledge_search_query], **knowledge_config
|
|
||||||
)
|
|
||||||
if agent_knowledge_snippets:
|
|
||||||
self.agent_knowledge_context = extract_knowledge_context(
|
|
||||||
agent_knowledge_snippets
|
|
||||||
)
|
|
||||||
if self.agent_knowledge_context:
|
|
||||||
task_prompt += self.agent_knowledge_context
|
|
||||||
|
|
||||||
# Quering crew specific knowledge
|
|
||||||
knowledge_snippets = self.crew.query_knowledge(
|
|
||||||
[self.knowledge_search_query], **knowledge_config
|
|
||||||
)
|
|
||||||
if knowledge_snippets:
|
|
||||||
self.crew_knowledge_context = extract_knowledge_context(
|
|
||||||
knowledge_snippets
|
|
||||||
)
|
|
||||||
if self.crew_knowledge_context:
|
|
||||||
task_prompt += self.crew_knowledge_context
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
self,
|
|
||||||
event=KnowledgeRetrievalCompletedEvent(
|
|
||||||
query=self.knowledge_search_query,
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
retrieved_knowledge=(
|
|
||||||
(self.agent_knowledge_context or "")
|
|
||||||
+ (
|
|
||||||
"\n"
|
|
||||||
if self.agent_knowledge_context
|
|
||||||
and self.crew_knowledge_context
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
+ (self.crew_knowledge_context or "")
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
self,
|
|
||||||
event=KnowledgeSearchQueryFailedEvent(
|
|
||||||
query=self.knowledge_search_query or "",
|
|
||||||
error=str(e),
|
|
||||||
from_task=task,
|
|
||||||
from_agent=self,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = tools or self.tools or []
|
|
||||||
self.create_agent_executor(tools=tools, task=task)
|
|
||||||
|
|
||||||
if self.crew and self.crew._train:
|
|
||||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
|
||||||
else:
|
|
||||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
|
||||||
|
|
||||||
# Import agent events locally to avoid circular imports
|
|
||||||
from crewai.events.types.agent_events import (
|
from crewai.events.types.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
AgentExecutionErrorEvent,
|
AgentExecutionErrorEvent,
|
||||||
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine execution method based on timeout setting
|
validate_max_execution_time(self.max_execution_time)
|
||||||
if self.max_execution_time is not None:
|
if self.max_execution_time is not None:
|
||||||
if (
|
|
||||||
not isinstance(self.max_execution_time, int)
|
|
||||||
or self.max_execution_time <= 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Max Execution time must be a positive integer greater than zero"
|
|
||||||
)
|
|
||||||
result = self._execute_with_timeout(
|
result = self._execute_with_timeout(
|
||||||
task_prompt, task, self.max_execution_time
|
task_prompt, task, self.max_execution_time
|
||||||
)
|
)
|
||||||
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
|
|||||||
result = self._execute_without_timeout(task_prompt, task)
|
result = self._execute_without_timeout(task_prompt, task)
|
||||||
|
|
||||||
except TimeoutError as e:
|
except TimeoutError as e:
|
||||||
# Propagate TimeoutError without retry
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionErrorEvent(
|
event=AgentExecutionErrorEvent(
|
||||||
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
|
|||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if e.__class__.__module__.startswith("litellm"):
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
# Do not retry on litellm errors
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionErrorEvent(
|
event=AgentExecutionErrorEvent(
|
||||||
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
|
|||||||
if self.max_rpm and self._rpm_controller:
|
if self.max_rpm and self._rpm_controller:
|
||||||
self._rpm_controller.stop_rpm_counter()
|
self._rpm_controller.stop_rpm_counter()
|
||||||
|
|
||||||
# If there was any tool in self.tools_results that had result_as_answer
|
result = process_tool_results(self, result)
|
||||||
# set to True, return the results of the last tool that had
|
|
||||||
# result_as_answer set to True
|
|
||||||
for tool_result in self.tools_results:
|
|
||||||
if tool_result.get("result_as_answer", False):
|
|
||||||
result = tool_result["result"]
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._last_messages = (
|
save_last_messages(self)
|
||||||
self.agent_executor.messages.copy()
|
|
||||||
if self.agent_executor and hasattr(self.agent_executor, "messages")
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cleanup_mcp_clients()
|
self._cleanup_mcp_clients()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
|
|||||||
}
|
}
|
||||||
)["output"]
|
)["output"]
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Task,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a task with the agent asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task to execute.
|
||||||
|
context: Context to execute the task in.
|
||||||
|
tools: Tools to use for the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output of the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If execution exceeds the maximum execution time.
|
||||||
|
ValueError: If the max execution time is not a positive integer.
|
||||||
|
RuntimeError: If the agent execution fails for other reasons.
|
||||||
|
"""
|
||||||
|
handle_reasoning(self, task)
|
||||||
|
self._inject_date_to_task(task)
|
||||||
|
|
||||||
|
if self.tools_handler:
|
||||||
|
self.tools_handler.last_used_tool = None
|
||||||
|
|
||||||
|
task_prompt = task.prompt()
|
||||||
|
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||||
|
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||||
|
|
||||||
|
if self._is_any_available_memory():
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryRetrievalStartedEvent(
|
||||||
|
task_id=str(task.id) if task else None,
|
||||||
|
source_type="agent",
|
||||||
|
from_agent=self,
|
||||||
|
from_task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
self.crew._short_term_memory,
|
||||||
|
self.crew._long_term_memory,
|
||||||
|
self.crew._entity_memory,
|
||||||
|
self.crew._external_memory,
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
memory = await contextual_memory.abuild_context_for_task(
|
||||||
|
task, context or ""
|
||||||
|
)
|
||||||
|
if memory.strip() != "":
|
||||||
|
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryRetrievalCompletedEvent(
|
||||||
|
task_id=str(task.id) if task else None,
|
||||||
|
memory_content=memory,
|
||||||
|
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="agent",
|
||||||
|
from_agent=self,
|
||||||
|
from_task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
knowledge_config = get_knowledge_config(self)
|
||||||
|
task_prompt = await ahandle_knowledge_retrieval(
|
||||||
|
self, task, task_prompt, knowledge_config
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_tools(self, tools, task)
|
||||||
|
task_prompt = apply_training_data(self, task_prompt)
|
||||||
|
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
AgentExecutionCompletedEvent,
|
||||||
|
AgentExecutionErrorEvent,
|
||||||
|
AgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionStartedEvent(
|
||||||
|
agent=self,
|
||||||
|
tools=self.tools,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
task=task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_max_execution_time(self.max_execution_time)
|
||||||
|
if self.max_execution_time is not None:
|
||||||
|
result = await self._aexecute_with_timeout(
|
||||||
|
task_prompt, task, self.max_execution_time
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self._aexecute_without_timeout(task_prompt, task)
|
||||||
|
|
||||||
|
except TimeoutError as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
self._times_executed += 1
|
||||||
|
if self._times_executed > self.max_retry_limit:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionErrorEvent(
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
|
error=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
result = await self.aexecute_task(task, context, tools)
|
||||||
|
|
||||||
|
if self.max_rpm and self._rpm_controller:
|
||||||
|
self._rpm_controller.stop_rpm_counter()
|
||||||
|
|
||||||
|
result = process_tool_results(self, result)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||||
|
)
|
||||||
|
|
||||||
|
save_last_messages(self)
|
||||||
|
self._cleanup_mcp_clients()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _aexecute_with_timeout(
|
||||||
|
self, task_prompt: str, task: Task, timeout: int
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a task with a timeout asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The prompt to send to the agent.
|
||||||
|
task: The task being executed.
|
||||||
|
timeout: Maximum execution time in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If execution exceeds the timeout.
|
||||||
|
RuntimeError: If execution fails for other reasons.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._aexecute_without_timeout(task_prompt, task),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Task '{task.description}' execution timed out after {timeout} seconds. "
|
||||||
|
"Consider increasing max_execution_time or optimizing the task."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||||
|
"""Execute a task without a timeout asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The prompt to send to the agent.
|
||||||
|
task: The task being executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the agent.
|
||||||
|
"""
|
||||||
|
if not self.agent_executor:
|
||||||
|
raise RuntimeError("Agent executor is not initialized.")
|
||||||
|
|
||||||
|
result = await self.agent_executor.ainvoke(
|
||||||
|
{
|
||||||
|
"input": task_prompt,
|
||||||
|
"tool_names": self.agent_executor.tools_names,
|
||||||
|
"tools": self.agent_executor.tools_description,
|
||||||
|
"ask_for_human_input": task.human_input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result["output"]
|
||||||
|
|
||||||
def create_agent_executor(
|
def create_agent_executor(
|
||||||
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.agent_executor = CrewAgentExecutor(
|
self.agent_executor = CrewAgentExecutor(
|
||||||
llm=self.llm,
|
llm=self.llm, # type: ignore[arg-type]
|
||||||
task=task, # type: ignore[arg-type]
|
task=task, # type: ignore[arg-type]
|
||||||
agent=self,
|
agent=self,
|
||||||
crew=self.crew,
|
crew=self.crew,
|
||||||
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
|
|||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||||
|
|
||||||
|
transport: StdioTransport | HTTPTransport | SSETransport
|
||||||
if isinstance(mcp_config, MCPServerStdio):
|
if isinstance(mcp_config, MCPServerStdio):
|
||||||
transport = StdioTransport(
|
transport = StdioTransport(
|
||||||
command=mcp_config.command,
|
command=mcp_config.command,
|
||||||
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
|
|||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
run_context=None,
|
run_context=None,
|
||||||
)
|
)
|
||||||
if mcp_config.tool_filter(context, tool):
|
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||||
filtered_tools.append(tool)
|
filtered_tools.append(tool)
|
||||||
except (TypeError, AttributeError):
|
except (TypeError, AttributeError):
|
||||||
if mcp_config.tool_filter(tool):
|
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||||
filtered_tools.append(tool)
|
filtered_tools.append(tool)
|
||||||
else:
|
else:
|
||||||
# Not callable - include tool
|
# Not callable - include tool
|
||||||
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
|
|||||||
path = parsed.path.replace("/", "_").strip("_")
|
path = parsed.path.replace("/", "_").strip("_")
|
||||||
return f"{domain}_{path}" if path else domain
|
return f"{domain}_{path}" if path else domain
|
||||||
|
|
||||||
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
|
def _get_mcp_tool_schemas(
|
||||||
|
self, server_params: dict[str, Any]
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||||
server_url = server_params["url"]
|
server_url = server_params["url"]
|
||||||
|
|
||||||
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
|
|||||||
self._logger.log(
|
self._logger.log(
|
||||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||||
)
|
)
|
||||||
return cached_data
|
return cached_data # type: ignore[no-any-return]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||||
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
async def _get_mcp_tool_schemas_async(
|
async def _get_mcp_tool_schemas_async(
|
||||||
self, server_params: dict[str, Any]
|
self, server_params: dict[str, Any]
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||||
server_url = server_params["url"]
|
server_url = server_params["url"]
|
||||||
return await self._retry_mcp_discovery(
|
return await self._retry_mcp_discovery(
|
||||||
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _retry_mcp_discovery(
|
async def _retry_mcp_discovery(
|
||||||
self, operation_func, server_url: str
|
self, operation_func: Any, server_url: str
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||||
last_error = None
|
last_error = None
|
||||||
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _attempt_mcp_discovery(
|
async def _attempt_mcp_discovery(
|
||||||
operation_func, server_url: str
|
operation_func: Any, server_url: str
|
||||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||||
try:
|
try:
|
||||||
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
|
|||||||
properties = json_schema.get("properties", {})
|
properties = json_schema.get("properties", {})
|
||||||
required_fields = json_schema.get("required", [])
|
required_fields = json_schema.get("required", [])
|
||||||
|
|
||||||
field_definitions = {}
|
field_definitions: dict[str, Any] = {}
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
for field_name, field_schema in properties.items():
|
||||||
field_type = self._json_type_to_python(field_schema)
|
field_type = self._json_type_to_python(field_schema)
|
||||||
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||||
return create_model(model_name, **field_definitions)
|
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||||
|
|
||||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||||
"""Convert JSON Schema type to Python type.
|
"""Convert JSON Schema type to Python type.
|
||||||
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
|
|||||||
json_type = field_schema.get("type")
|
json_type = field_schema.get("type")
|
||||||
|
|
||||||
if "anyOf" in field_schema:
|
if "anyOf" in field_schema:
|
||||||
types = []
|
types: list[type] = []
|
||||||
for option in field_schema["anyOf"]:
|
for option in field_schema["anyOf"]:
|
||||||
if "const" in option:
|
if "const" in option:
|
||||||
types.append(str)
|
types.append(str)
|
||||||
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
|
|||||||
types.append(self._json_type_to_python(option))
|
types.append(self._json_type_to_python(option))
|
||||||
unique_types = list(set(types))
|
unique_types = list(set(types))
|
||||||
if len(unique_types) > 1:
|
if len(unique_types) > 1:
|
||||||
result = unique_types[0]
|
result: Any = unique_types[0]
|
||||||
for t in unique_types[1:]:
|
for t in unique_types[1:]:
|
||||||
result = result | t
|
result = result | t
|
||||||
return result
|
return result # type: ignore[no-any-return]
|
||||||
return unique_types[0]
|
return unique_types[0]
|
||||||
|
|
||||||
type_mapping = {
|
type_mapping: dict[str | None, type] = {
|
||||||
"string": str,
|
"string": str,
|
||||||
"number": float,
|
"number": float,
|
||||||
"integer": int,
|
"integer": int,
|
||||||
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
|
|||||||
return type_mapping.get(json_type, Any)
|
return type_mapping.get(json_type, Any)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||||
"""Fetch MCP server configurations from CrewAI AOP API."""
|
"""Fetch MCP server configurations from CrewAI AOP API."""
|
||||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||||
# Should return list of server configs with URLs
|
# Should return list of server configs with URLs
|
||||||
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
if self.apps:
|
if self.apps:
|
||||||
platform_tools = self.get_platform_tools(self.apps)
|
platform_tools = self.get_platform_tools(self.apps)
|
||||||
if platform_tools:
|
if platform_tools and self.tools is not None:
|
||||||
self.tools.extend(platform_tools)
|
self.tools.extend(platform_tools)
|
||||||
if self.mcps:
|
if self.mcps:
|
||||||
mcps = self.get_mcp_tools(self.mcps)
|
mcps = self.get_mcp_tools(self.mcps)
|
||||||
if mcps:
|
if mcps and self.tools is not None:
|
||||||
self.tools.extend(mcps)
|
self.tools.extend(mcps)
|
||||||
|
|
||||||
lite_agent = LiteAgent(
|
lite_agent = LiteAgent(
|
||||||
|
|||||||
355
lib/crewai/src/crewai/agent/utils.py
Normal file
355
lib/crewai/src/crewai/agent/utils.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""Utility functions for agent task execution.
|
||||||
|
|
||||||
|
This module contains shared logic extracted from the Agent's execute_task
|
||||||
|
and aexecute_task methods to reduce code duplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.knowledge_events import (
|
||||||
|
KnowledgeRetrievalCompletedEvent,
|
||||||
|
KnowledgeRetrievalStartedEvent,
|
||||||
|
KnowledgeSearchQueryFailedEvent,
|
||||||
|
)
|
||||||
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
|
from crewai.utilities.converter import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
from crewai.utilities.i18n import I18N
|
||||||
|
|
||||||
|
|
||||||
|
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||||
|
"""Handle the reasoning process for an agent before task execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task to execute.
|
||||||
|
"""
|
||||||
|
if not agent.reasoning:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from crewai.utilities.reasoning_handler import (
|
||||||
|
AgentReasoning,
|
||||||
|
AgentReasoningOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||||
|
reasoning_output: AgentReasoningOutput = (
|
||||||
|
reasoning_handler.handle_agent_reasoning()
|
||||||
|
)
|
||||||
|
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||||
|
except Exception as e:
|
||||||
|
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||||
|
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The initial task prompt.
|
||||||
|
i18n: Internationalization instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with schema instructions.
|
||||||
|
"""
|
||||||
|
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||||
|
if task.output_json:
|
||||||
|
schema_dict = generate_model_description(task.output_json)
|
||||||
|
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||||
|
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||||
|
output_format=schema
|
||||||
|
)
|
||||||
|
elif task.output_pydantic:
|
||||||
|
schema_dict = generate_model_description(task.output_pydantic)
|
||||||
|
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||||
|
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||||
|
output_format=schema
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
|
||||||
|
"""Format task prompt with context if provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The task prompt.
|
||||||
|
context: Optional context string.
|
||||||
|
i18n: Internationalization instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt formatted with context if provided.
|
||||||
|
"""
|
||||||
|
if context:
|
||||||
|
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
|
||||||
|
"""Get knowledge configuration from agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of knowledge configuration.
|
||||||
|
"""
|
||||||
|
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
|
||||||
|
|
||||||
|
|
||||||
|
def handle_knowledge_retrieval(
|
||||||
|
agent: Agent,
|
||||||
|
task: Task,
|
||||||
|
task_prompt: str,
|
||||||
|
knowledge_config: dict[str, Any],
|
||||||
|
query_func: Any,
|
||||||
|
crew_query_func: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Handle knowledge retrieval for task execution.
|
||||||
|
|
||||||
|
This function handles both agent-specific and crew-specific knowledge queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The current task prompt.
|
||||||
|
knowledge_config: Knowledge configuration dictionary.
|
||||||
|
query_func: Function to query agent knowledge (sync or async).
|
||||||
|
crew_query_func: Function to query crew knowledge (sync or async).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with knowledge context.
|
||||||
|
"""
|
||||||
|
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalStartedEvent(
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||||
|
task_prompt, task
|
||||||
|
)
|
||||||
|
if agent.knowledge_search_query:
|
||||||
|
if agent.knowledge:
|
||||||
|
agent_knowledge_snippets = query_func(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if agent_knowledge_snippets:
|
||||||
|
agent.agent_knowledge_context = extract_knowledge_context(
|
||||||
|
agent_knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.agent_knowledge_context:
|
||||||
|
task_prompt += agent.agent_knowledge_context
|
||||||
|
|
||||||
|
knowledge_snippets = crew_query_func(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if knowledge_snippets:
|
||||||
|
agent.crew_knowledge_context = extract_knowledge_context(
|
||||||
|
knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.crew_knowledge_context:
|
||||||
|
task_prompt += agent.crew_knowledge_context
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalCompletedEvent(
|
||||||
|
query=agent.knowledge_search_query,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeSearchQueryFailedEvent(
|
||||||
|
query=agent.knowledge_search_query or "",
|
||||||
|
error=str(e),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_knowledge_context(agent: Agent) -> str:
|
||||||
|
"""Combine agent and crew knowledge contexts into a single string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with knowledge contexts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined knowledge context string.
|
||||||
|
"""
|
||||||
|
agent_ctx = agent.agent_knowledge_context or ""
|
||||||
|
crew_ctx = agent.crew_knowledge_context or ""
|
||||||
|
separator = "\n" if agent_ctx and crew_ctx else ""
|
||||||
|
return agent_ctx + separator + crew_ctx
|
||||||
|
|
||||||
|
|
||||||
|
def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||||
|
"""Apply training data to the task prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task_prompt: The task prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt with training data applied.
|
||||||
|
"""
|
||||||
|
if agent.crew and agent.crew._train:
|
||||||
|
return agent._training_handler(task_prompt=task_prompt)
|
||||||
|
return agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||||
|
"""Process tool results, returning result_as_answer if applicable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with tool results.
|
||||||
|
result: The current result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The final result, potentially overridden by tool result_as_answer.
|
||||||
|
"""
|
||||||
|
for tool_result in agent.tools_results:
|
||||||
|
if tool_result.get("result_as_answer", False):
|
||||||
|
result = tool_result["result"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def save_last_messages(agent: Agent) -> None:
|
||||||
|
"""Save the last messages from agent executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
"""
|
||||||
|
agent._last_messages = (
|
||||||
|
agent.agent_executor.messages.copy()
|
||||||
|
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tools(
|
||||||
|
agent: Agent, tools: list[BaseTool] | None, task: Task
|
||||||
|
) -> list[BaseTool]:
|
||||||
|
"""Prepare tools for task execution and create agent executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance.
|
||||||
|
tools: Optional list of tools.
|
||||||
|
task: The task being executed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of tools to use.
|
||||||
|
"""
|
||||||
|
final_tools = tools or agent.tools or []
|
||||||
|
agent.create_agent_executor(tools=final_tools, task=task)
|
||||||
|
return final_tools
|
||||||
|
|
||||||
|
|
||||||
|
def validate_max_execution_time(max_execution_time: int | None) -> None:
|
||||||
|
"""Validate max_execution_time parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_execution_time: The maximum execution time to validate.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If max_execution_time is not a positive integer.
|
||||||
|
"""
|
||||||
|
if max_execution_time is not None:
|
||||||
|
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Max Execution time must be a positive integer greater than zero"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def ahandle_knowledge_retrieval(
|
||||||
|
agent: Agent,
|
||||||
|
task: Task,
|
||||||
|
task_prompt: str,
|
||||||
|
knowledge_config: dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Handle async knowledge retrieval for task execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent performing the task.
|
||||||
|
task: The task being executed.
|
||||||
|
task_prompt: The current task prompt.
|
||||||
|
knowledge_config: Knowledge configuration dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task prompt potentially augmented with knowledge context.
|
||||||
|
"""
|
||||||
|
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalStartedEvent(
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||||
|
task_prompt, task
|
||||||
|
)
|
||||||
|
if agent.knowledge_search_query:
|
||||||
|
if agent.knowledge:
|
||||||
|
agent_knowledge_snippets = await agent.knowledge.aquery(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if agent_knowledge_snippets:
|
||||||
|
agent.agent_knowledge_context = extract_knowledge_context(
|
||||||
|
agent_knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.agent_knowledge_context:
|
||||||
|
task_prompt += agent.agent_knowledge_context
|
||||||
|
|
||||||
|
knowledge_snippets = await agent.crew.aquery_knowledge(
|
||||||
|
[agent.knowledge_search_query], **knowledge_config
|
||||||
|
)
|
||||||
|
if knowledge_snippets:
|
||||||
|
agent.crew_knowledge_context = extract_knowledge_context(
|
||||||
|
knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent.crew_knowledge_context:
|
||||||
|
task_prompt += agent.crew_knowledge_context
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeRetrievalCompletedEvent(
|
||||||
|
query=agent.knowledge_search_query,
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent,
|
||||||
|
event=KnowledgeSearchQueryFailedEvent(
|
||||||
|
query=agent.knowledge_search_query or "",
|
||||||
|
error=str(e),
|
||||||
|
from_task=task,
|
||||||
|
from_agent=agent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
if not mcps:
|
if not mcps:
|
||||||
return mcps
|
return mcps
|
||||||
|
|
||||||
validated_mcps = []
|
validated_mcps: list[str | MCPServerConfig] = []
|
||||||
for mcp in mcps:
|
for mcp in mcps:
|
||||||
if isinstance(mcp, str):
|
if isinstance(mcp, str):
|
||||||
if mcp.startswith(("https://", "crewai-amp:")):
|
if mcp.startswith(("https://", "crewai-amp:")):
|
||||||
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
) -> str:
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Execute a task asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
|
|||||||
get_before_llm_call_hooks,
|
get_before_llm_call_hooks,
|
||||||
)
|
)
|
||||||
from crewai.utilities.agent_utils import (
|
from crewai.utilities.agent_utils import (
|
||||||
|
aget_llm_response,
|
||||||
enforce_rpm_limit,
|
enforce_rpm_limit,
|
||||||
format_message_for_llm,
|
format_message_for_llm,
|
||||||
get_llm_response,
|
get_llm_response,
|
||||||
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
|
|||||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||||
from crewai.utilities.i18n import I18N, get_i18n
|
from crewai.utilities.i18n import I18N, get_i18n
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
from crewai.utilities.tool_utils import (
|
||||||
|
aexecute_tool_and_check_finality,
|
||||||
|
execute_tool_and_check_finality,
|
||||||
|
)
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self.messages: list[LLMMessage] = []
|
self.messages: list[LLMMessage] = []
|
||||||
self.iterations = 0
|
self.iterations = 0
|
||||||
self.log_error_after = 3
|
self.log_error_after = 3
|
||||||
self.before_llm_call_hooks: list[Callable] = []
|
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||||
self.after_llm_call_hooks: list[Callable] = []
|
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||||
if self.llm:
|
if self.llm:
|
||||||
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self._show_logs(formatted_answer)
|
self._show_logs(formatted_answer)
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
|
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Execute the agent asynchronously with given inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Input dictionary containing prompt variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with agent output.
|
||||||
|
"""
|
||||||
|
if "system" in self.prompt:
|
||||||
|
system_prompt = self._format_prompt(
|
||||||
|
cast(str, self.prompt.get("system", "")), inputs
|
||||||
|
)
|
||||||
|
user_prompt = self._format_prompt(
|
||||||
|
cast(str, self.prompt.get("user", "")), inputs
|
||||||
|
)
|
||||||
|
self.messages.append(format_message_for_llm(system_prompt, role="system"))
|
||||||
|
self.messages.append(format_message_for_llm(user_prompt))
|
||||||
|
else:
|
||||||
|
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||||
|
self.messages.append(format_message_for_llm(user_prompt))
|
||||||
|
|
||||||
|
self._show_start_logs()
|
||||||
|
|
||||||
|
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
|
||||||
|
|
||||||
|
try:
|
||||||
|
formatted_answer = await self._ainvoke_loop()
|
||||||
|
except AssertionError:
|
||||||
|
self._printer.print(
|
||||||
|
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
handle_unknown_error(self._printer, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if self.ask_for_human_input:
|
||||||
|
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||||
|
|
||||||
|
self._create_short_term_memory(formatted_answer)
|
||||||
|
self._create_long_term_memory(formatted_answer)
|
||||||
|
self._create_external_memory(formatted_answer)
|
||||||
|
return {"output": formatted_answer.output}
|
||||||
|
|
||||||
|
async def _ainvoke_loop(self) -> AgentFinish:
|
||||||
|
"""Execute agent loop asynchronously until completion.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final answer from the agent.
|
||||||
|
"""
|
||||||
|
formatted_answer = None
|
||||||
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
|
try:
|
||||||
|
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||||
|
formatted_answer = handle_max_iterations_exceeded(
|
||||||
|
formatted_answer,
|
||||||
|
printer=self._printer,
|
||||||
|
i18n=self._i18n,
|
||||||
|
messages=self.messages,
|
||||||
|
llm=self.llm,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||||
|
|
||||||
|
answer = await aget_llm_response(
|
||||||
|
llm=self.llm,
|
||||||
|
messages=self.messages,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
printer=self._printer,
|
||||||
|
from_task=self.task,
|
||||||
|
from_agent=self.agent,
|
||||||
|
response_model=self.response_model,
|
||||||
|
executor_context=self,
|
||||||
|
)
|
||||||
|
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||||
|
|
||||||
|
if isinstance(formatted_answer, AgentAction):
|
||||||
|
fingerprint_context = {}
|
||||||
|
if (
|
||||||
|
self.agent
|
||||||
|
and hasattr(self.agent, "security_config")
|
||||||
|
and hasattr(self.agent.security_config, "fingerprint")
|
||||||
|
):
|
||||||
|
fingerprint_context = {
|
||||||
|
"agent_fingerprint": str(
|
||||||
|
self.agent.security_config.fingerprint
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_result = await aexecute_tool_and_check_finality(
|
||||||
|
agent_action=formatted_answer,
|
||||||
|
fingerprint_context=fingerprint_context,
|
||||||
|
tools=self.tools,
|
||||||
|
i18n=self._i18n,
|
||||||
|
agent_key=self.agent.key if self.agent else None,
|
||||||
|
agent_role=self.agent.role if self.agent else None,
|
||||||
|
tools_handler=self.tools_handler,
|
||||||
|
task=self.task,
|
||||||
|
agent=self.agent,
|
||||||
|
function_calling_llm=self.function_calling_llm,
|
||||||
|
crew=self.crew,
|
||||||
|
)
|
||||||
|
formatted_answer = self._handle_agent_action(
|
||||||
|
formatted_answer, tool_result
|
||||||
|
)
|
||||||
|
|
||||||
|
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
|
||||||
|
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
|
||||||
|
|
||||||
|
except OutputParserError as e:
|
||||||
|
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||||
|
e=e,
|
||||||
|
messages=self.messages,
|
||||||
|
iterations=self.iterations,
|
||||||
|
log_error_after=self.log_error_after,
|
||||||
|
printer=self._printer,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
|
raise e
|
||||||
|
if is_context_length_exceeded(e):
|
||||||
|
handle_context_length(
|
||||||
|
respect_context_window=self.respect_context_window,
|
||||||
|
printer=self._printer,
|
||||||
|
messages=self.messages,
|
||||||
|
llm=self.llm,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
i18n=self._i18n,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
handle_unknown_error(self._printer, e)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self.iterations += 1
|
||||||
|
|
||||||
|
if not isinstance(formatted_answer, AgentFinish):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Agent execution ended without reaching a final answer. "
|
||||||
|
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
|
||||||
|
)
|
||||||
|
self._show_logs(formatted_answer)
|
||||||
|
return formatted_answer
|
||||||
|
|
||||||
def _handle_agent_action(
|
def _handle_agent_action(
|
||||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||||
) -> AgentAction | AgentFinish:
|
) -> AgentAction | AgentFinish:
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
def set_private_attrs(self) -> Crew:
|
def set_private_attrs(self) -> Crew:
|
||||||
"""set private attributes."""
|
"""set private attributes."""
|
||||||
self._cache_handler = CacheHandler()
|
self._cache_handler = CacheHandler()
|
||||||
event_listener = EventListener() # type: ignore[no-untyped-call]
|
event_listener = EventListener()
|
||||||
|
|
||||||
# Determine and set tracing state once for this execution
|
# Determine and set tracing state once for this execution
|
||||||
tracing_enabled = should_enable_tracing(override=self.tracing)
|
tracing_enabled = should_enable_tracing(override=self.tracing)
|
||||||
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _initialize_default_memories(self) -> None:
|
def _initialize_default_memories(self) -> None:
|
||||||
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
|
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||||
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
|
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||||
crew=self,
|
crew=self,
|
||||||
embedder_config=self.embedder,
|
embedder_config=self.embedder,
|
||||||
)
|
)
|
||||||
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
|
self._entity_memory = self.entity_memory or EntityMemory(
|
||||||
crew=self, embedder_config=self.embedder
|
crew=self, embedder_config=self.embedder
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -948,6 +948,342 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
self._task_output_handler.reset()
|
self._task_output_handler.reset()
|
||||||
return list(results)
|
return list(results)
|
||||||
|
|
||||||
|
async def akickoff(
|
||||||
|
self, inputs: dict[str, Any] | None = None
|
||||||
|
) -> CrewOutput | CrewStreamingOutput:
|
||||||
|
"""Native async kickoff method using async task execution throughout.
|
||||||
|
|
||||||
|
Unlike kickoff_async which wraps sync kickoff in a thread, this method
|
||||||
|
uses native async/await for all operations including task execution,
|
||||||
|
memory operations, and knowledge queries.
|
||||||
|
"""
|
||||||
|
if self.stream:
|
||||||
|
for agent in self.agents:
|
||||||
|
if agent.llm is not None:
|
||||||
|
agent.llm.stream = True
|
||||||
|
|
||||||
|
result_holder: list[CrewOutput] = []
|
||||||
|
current_task_info: TaskInfo = {
|
||||||
|
"index": 0,
|
||||||
|
"name": "",
|
||||||
|
"id": "",
|
||||||
|
"agent_role": "",
|
||||||
|
"agent_id": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
state = create_streaming_state(
|
||||||
|
current_task_info, result_holder, use_async=True
|
||||||
|
)
|
||||||
|
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||||
|
|
||||||
|
async def run_crew() -> None:
|
||||||
|
try:
|
||||||
|
self.stream = False
|
||||||
|
result = await self.akickoff(inputs)
|
||||||
|
if isinstance(result, CrewOutput):
|
||||||
|
result_holder.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
signal_error(state, e, is_async=True)
|
||||||
|
finally:
|
||||||
|
self.stream = True
|
||||||
|
signal_end(state, is_async=True)
|
||||||
|
|
||||||
|
streaming_output = CrewStreamingOutput(
|
||||||
|
async_iterator=create_async_chunk_generator(
|
||||||
|
state, run_crew, output_holder
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output_holder.append(streaming_output)
|
||||||
|
|
||||||
|
return streaming_output
|
||||||
|
|
||||||
|
ctx = baggage.set_baggage(
|
||||||
|
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||||
|
)
|
||||||
|
token = attach(ctx)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for before_callback in self.before_kickoff_callbacks:
|
||||||
|
if inputs is None:
|
||||||
|
inputs = {}
|
||||||
|
inputs = before_callback(inputs)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._task_output_handler.reset()
|
||||||
|
self._logging_color = "bold_purple"
|
||||||
|
|
||||||
|
if inputs is not None:
|
||||||
|
self._inputs = inputs
|
||||||
|
self._interpolate_inputs(inputs)
|
||||||
|
self._set_tasks_callbacks()
|
||||||
|
self._set_allow_crewai_trigger_context_for_first_task()
|
||||||
|
|
||||||
|
for agent in self.agents:
|
||||||
|
agent.crew = self
|
||||||
|
agent.set_knowledge(crew_embedder=self.embedder)
|
||||||
|
if not agent.function_calling_llm: # type: ignore[attr-defined]
|
||||||
|
agent.function_calling_llm = self.function_calling_llm # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
if not agent.step_callback: # type: ignore[attr-defined]
|
||||||
|
agent.step_callback = self.step_callback # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
agent.create_agent_executor()
|
||||||
|
|
||||||
|
if self.planning:
|
||||||
|
self._handle_crew_planning()
|
||||||
|
|
||||||
|
if self.process == Process.sequential:
|
||||||
|
result = await self._arun_sequential_process()
|
||||||
|
elif self.process == Process.hierarchical:
|
||||||
|
result = await self._arun_hierarchical_process()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"The process '{self.process}' is not implemented yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
for after_callback in self.after_kickoff_callbacks:
|
||||||
|
result = after_callback(result)
|
||||||
|
|
||||||
|
self.usage_metrics = self.calculate_usage_metrics()
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
detach(token)
|
||||||
|
|
||||||
|
async def akickoff_for_each(
|
||||||
|
self, inputs: list[dict[str, Any]]
|
||||||
|
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||||
|
"""Native async execution of the Crew's workflow for each input.
|
||||||
|
|
||||||
|
Uses native async throughout rather than thread-based async.
|
||||||
|
If stream=True, returns a single CrewStreamingOutput that yields chunks
|
||||||
|
from all crews as they arrive.
|
||||||
|
"""
|
||||||
|
crew_copies = [self.copy() for _ in inputs]
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
result_holder: list[list[CrewOutput]] = [[]]
|
||||||
|
current_task_info: TaskInfo = {
|
||||||
|
"index": 0,
|
||||||
|
"name": "",
|
||||||
|
"id": "",
|
||||||
|
"agent_role": "",
|
||||||
|
"agent_id": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
state = create_streaming_state(
|
||||||
|
current_task_info, result_holder, use_async=True
|
||||||
|
)
|
||||||
|
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||||
|
|
||||||
|
async def run_all_crews() -> None:
|
||||||
|
try:
|
||||||
|
streaming_outputs: list[CrewStreamingOutput] = []
|
||||||
|
for i, crew in enumerate(crew_copies):
|
||||||
|
streaming = await crew.akickoff(inputs=inputs[i])
|
||||||
|
if isinstance(streaming, CrewStreamingOutput):
|
||||||
|
streaming_outputs.append(streaming)
|
||||||
|
|
||||||
|
async def consume_stream(
|
||||||
|
stream_output: CrewStreamingOutput,
|
||||||
|
) -> CrewOutput:
|
||||||
|
async for chunk in stream_output:
|
||||||
|
if state.async_queue is not None and state.loop is not None:
|
||||||
|
state.loop.call_soon_threadsafe(
|
||||||
|
state.async_queue.put_nowait, chunk
|
||||||
|
)
|
||||||
|
return stream_output.result
|
||||||
|
|
||||||
|
crew_results = await asyncio.gather(
|
||||||
|
*[consume_stream(s) for s in streaming_outputs]
|
||||||
|
)
|
||||||
|
result_holder[0] = list(crew_results)
|
||||||
|
except Exception as e:
|
||||||
|
signal_error(state, e, is_async=True)
|
||||||
|
finally:
|
||||||
|
signal_end(state, is_async=True)
|
||||||
|
|
||||||
|
streaming_output = CrewStreamingOutput(
|
||||||
|
async_iterator=create_async_chunk_generator(
|
||||||
|
state, run_all_crews, output_holder
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_results_wrapper(result: Any) -> None:
|
||||||
|
streaming_output._set_results(result)
|
||||||
|
|
||||||
|
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||||
|
output_holder.append(streaming_output)
|
||||||
|
|
||||||
|
return streaming_output
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(crew_copy.akickoff(inputs=input_data))
|
||||||
|
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
total_usage_metrics = UsageMetrics()
|
||||||
|
for crew_copy in crew_copies:
|
||||||
|
if crew_copy.usage_metrics:
|
||||||
|
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||||
|
self.usage_metrics = total_usage_metrics
|
||||||
|
|
||||||
|
self._task_output_handler.reset()
|
||||||
|
return list(results)
|
||||||
|
|
||||||
|
async def _arun_sequential_process(self) -> CrewOutput:
|
||||||
|
"""Executes tasks sequentially using native async and returns the final output."""
|
||||||
|
return await self._aexecute_tasks(self.tasks)
|
||||||
|
|
||||||
|
async def _arun_hierarchical_process(self) -> CrewOutput:
|
||||||
|
"""Creates and assigns a manager agent to complete the tasks using native async."""
|
||||||
|
self._create_manager_agent()
|
||||||
|
return await self._aexecute_tasks(self.tasks)
|
||||||
|
|
||||||
|
async def _aexecute_tasks(
|
||||||
|
self,
|
||||||
|
tasks: list[Task],
|
||||||
|
start_index: int | None = 0,
|
||||||
|
was_replayed: bool = False,
|
||||||
|
) -> CrewOutput:
|
||||||
|
"""Executes tasks using native async and returns the final output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks: List of tasks to execute
|
||||||
|
start_index: Index to start execution from (for replay)
|
||||||
|
was_replayed: Whether this is a replayed execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CrewOutput: Final output of the crew
|
||||||
|
"""
|
||||||
|
task_outputs: list[TaskOutput] = []
|
||||||
|
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
||||||
|
last_sync_output: TaskOutput | None = None
|
||||||
|
|
||||||
|
for task_index, task in enumerate(tasks):
|
||||||
|
if start_index is not None and task_index < start_index:
|
||||||
|
if task.output:
|
||||||
|
if task.async_execution:
|
||||||
|
task_outputs.append(task.output)
|
||||||
|
else:
|
||||||
|
task_outputs = [task.output]
|
||||||
|
last_sync_output = task.output
|
||||||
|
continue
|
||||||
|
|
||||||
|
agent_to_use = self._get_agent_to_use(task)
|
||||||
|
if agent_to_use is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No agent available for task: {task.description}. "
|
||||||
|
f"Ensure that either the task has an assigned agent "
|
||||||
|
f"or a manager agent is provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
tools_for_task = task.tools or agent_to_use.tools or []
|
||||||
|
tools_for_task = self._prepare_tools(
|
||||||
|
agent_to_use,
|
||||||
|
task,
|
||||||
|
tools_for_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._log_task_start(task, agent_to_use.role)
|
||||||
|
|
||||||
|
if isinstance(task, ConditionalTask):
|
||||||
|
skipped_task_output = await self._ahandle_conditional_task(
|
||||||
|
task, task_outputs, pending_tasks, task_index, was_replayed
|
||||||
|
)
|
||||||
|
if skipped_task_output:
|
||||||
|
task_outputs.append(skipped_task_output)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if task.async_execution:
|
||||||
|
context = self._get_context(
|
||||||
|
task, [last_sync_output] if last_sync_output else []
|
||||||
|
)
|
||||||
|
async_task = asyncio.create_task(
|
||||||
|
task.aexecute_sync(
|
||||||
|
agent=agent_to_use,
|
||||||
|
context=context,
|
||||||
|
tools=tools_for_task,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pending_tasks.append((task, async_task, task_index))
|
||||||
|
else:
|
||||||
|
if pending_tasks:
|
||||||
|
task_outputs = await self._aprocess_async_tasks(
|
||||||
|
pending_tasks, was_replayed
|
||||||
|
)
|
||||||
|
pending_tasks.clear()
|
||||||
|
|
||||||
|
context = self._get_context(task, task_outputs)
|
||||||
|
task_output = await task.aexecute_sync(
|
||||||
|
agent=agent_to_use,
|
||||||
|
context=context,
|
||||||
|
tools=tools_for_task,
|
||||||
|
)
|
||||||
|
task_outputs.append(task_output)
|
||||||
|
self._process_task_result(task, task_output)
|
||||||
|
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||||
|
|
||||||
|
if pending_tasks:
|
||||||
|
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||||
|
|
||||||
|
return self._create_crew_output(task_outputs)
|
||||||
|
|
||||||
|
async def _ahandle_conditional_task(
|
||||||
|
self,
|
||||||
|
task: ConditionalTask,
|
||||||
|
task_outputs: list[TaskOutput],
|
||||||
|
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||||
|
task_index: int,
|
||||||
|
was_replayed: bool,
|
||||||
|
) -> TaskOutput | None:
|
||||||
|
"""Handle conditional task evaluation using native async."""
|
||||||
|
if pending_tasks:
|
||||||
|
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||||
|
pending_tasks.clear()
|
||||||
|
|
||||||
|
previous_output = task_outputs[-1] if task_outputs else None
|
||||||
|
if previous_output is not None and not task.should_execute(previous_output):
|
||||||
|
self._logger.log(
|
||||||
|
"debug",
|
||||||
|
f"Skipping conditional task: {task.description}",
|
||||||
|
color="yellow",
|
||||||
|
)
|
||||||
|
skipped_task_output = task.get_skipped_task_output()
|
||||||
|
|
||||||
|
if not was_replayed:
|
||||||
|
self._store_execution_log(task, skipped_task_output, task_index)
|
||||||
|
return skipped_task_output
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _aprocess_async_tasks(
|
||||||
|
self,
|
||||||
|
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||||
|
was_replayed: bool = False,
|
||||||
|
) -> list[TaskOutput]:
|
||||||
|
"""Process pending async tasks and return their outputs."""
|
||||||
|
task_outputs: list[TaskOutput] = []
|
||||||
|
for future_task, async_task, task_index in pending_tasks:
|
||||||
|
task_output = await async_task
|
||||||
|
task_outputs.append(task_output)
|
||||||
|
self._process_task_result(future_task, task_output)
|
||||||
|
self._store_execution_log(
|
||||||
|
future_task, task_output, task_index, was_replayed
|
||||||
|
)
|
||||||
|
return task_outputs
|
||||||
|
|
||||||
def _handle_crew_planning(self) -> None:
|
def _handle_crew_planning(self) -> None:
|
||||||
"""Handles the Crew planning."""
|
"""Handles the Crew planning."""
|
||||||
self._logger.log("info", "Planning the crew execution")
|
self._logger.log("info", "Planning the crew execution")
|
||||||
@@ -1431,6 +1767,16 @@ class Crew(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def aquery_knowledge(
|
||||||
|
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||||
|
) -> list[SearchResult] | None:
|
||||||
|
"""Query the crew's knowledge base for relevant information asynchronously."""
|
||||||
|
if self.knowledge:
|
||||||
|
return await self.knowledge.aquery(
|
||||||
|
query, results_limit=results_limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def fetch_inputs(self) -> set[str]:
|
def fetch_inputs(self) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||||
|
|||||||
@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
finally:
|
finally:
|
||||||
detach(flow_token)
|
detach(flow_token)
|
||||||
|
|
||||||
|
async def akickoff(
|
||||||
|
self, inputs: dict[str, Any] | None = None
|
||||||
|
) -> Any | FlowStreamingOutput:
|
||||||
|
"""Native async method to start the flow execution. Alias for kickoff_async.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The final output from the flow, which is the result of the last executed method.
|
||||||
|
"""
|
||||||
|
return await self.kickoff_async(inputs)
|
||||||
|
|
||||||
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
||||||
"""Executes a flow's start method and its triggered listeners.
|
"""Executes a flow's start method and its triggered listeners.
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
|
|||||||
sources: list[BaseKnowledgeSource],
|
sources: list[BaseKnowledgeSource],
|
||||||
embedder: EmbedderConfig | None = None,
|
embedder: EmbedderConfig | None = None,
|
||||||
storage: KnowledgeStorage | None = None,
|
storage: KnowledgeStorage | None = None,
|
||||||
**data,
|
**data: object,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
if storage:
|
if storage:
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
|||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Storage is not initialized.")
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|
||||||
|
async def aquery(
|
||||||
|
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Query across all knowledge sources asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: List of query strings.
|
||||||
|
results_limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The top results matching the query.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If storage is not initialized.
|
||||||
|
"""
|
||||||
|
if self.storage is None:
|
||||||
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|
||||||
|
return await self.storage.asearch(
|
||||||
|
query,
|
||||||
|
limit=results_limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aadd_sources(self) -> None:
|
||||||
|
"""Add all knowledge sources to storage asynchronously."""
|
||||||
|
try:
|
||||||
|
for source in self.sources:
|
||||||
|
source.storage = self.storage
|
||||||
|
await source.aadd()
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.areset()
|
||||||
|
else:
|
||||||
|
raise ValueError("Storage is not initialized.")
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info): # noqa: N805
|
@classmethod
|
||||||
|
def validate_file_path(
|
||||||
|
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||||
|
) -> Path | list[Path] | str | list[str] | None:
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
raise ValueError("Either file_path or file_paths must be provided")
|
raise ValueError("Either file_path or file_paths must be provided")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _: Any) -> None:
|
||||||
"""Post-initialization method to load content."""
|
"""Post-initialization method to load content."""
|
||||||
self.safe_file_paths = self._process_file_paths()
|
self.safe_file_paths = self._process_file_paths()
|
||||||
self.validate_content()
|
self.validate_content()
|
||||||
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
def load_content(self) -> dict[Path, str]:
|
def load_content(self) -> dict[Path, str]:
|
||||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _save_documents(self):
|
def _save_documents(self) -> None:
|
||||||
"""Save the documents to the storage."""
|
"""Save the documents to the storage."""
|
||||||
if self.storage:
|
if self.storage:
|
||||||
self.storage.save(self.chunks)
|
self.storage.save(self.chunks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
|
async def _asave_documents(self) -> None:
|
||||||
|
"""Save the documents to the storage asynchronously."""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.asave(self.chunks)
|
||||||
|
else:
|
||||||
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
def convert_to_path(self, path: Path | str) -> Path:
|
def convert_to_path(self, path: Path | str) -> Path:
|
||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|||||||
@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _save_documents(self):
|
def _save_documents(self) -> None:
|
||||||
"""
|
"""Save the documents to the storage.
|
||||||
Save the documents to the storage.
|
|
||||||
This method should be called after the chunks and embeddings are generated.
|
This method should be called after the chunks and embeddings are generated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no storage is configured.
|
||||||
"""
|
"""
|
||||||
if self.storage:
|
if self.storage:
|
||||||
self.storage.save(self.chunks)
|
self.storage.save(self.chunks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
|
||||||
|
|
||||||
|
async def _asave_documents(self) -> None:
|
||||||
|
"""Save the documents to the storage asynchronously.
|
||||||
|
|
||||||
|
This method should be called after the chunks and embeddings are generated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no storage is configured.
|
||||||
|
"""
|
||||||
|
if self.storage:
|
||||||
|
await self.storage.asave(self.chunks)
|
||||||
|
else:
|
||||||
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|||||||
@@ -2,27 +2,24 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
from docling.datamodel.base_models import InputFormat
|
||||||
InputFormat,
|
from docling.document_converter import DocumentConverter
|
||||||
)
|
from docling.exceptions import ConversionError
|
||||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||||
DocumentConverter,
|
from docling_core.types.doc.document import DoclingDocument
|
||||||
)
|
|
||||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
|
||||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
|
||||||
HierarchicalChunker,
|
|
||||||
)
|
|
||||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
|
||||||
DoclingDocument,
|
|
||||||
)
|
|
||||||
|
|
||||||
DOCLING_AVAILABLE = True
|
DOCLING_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
DOCLING_AVAILABLE = False
|
DOCLING_AVAILABLE = False
|
||||||
|
# Provide type stubs for when docling is not available
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from docling.document_converter import DocumentConverter
|
||||||
|
from docling_core.types.doc.document import DoclingDocument
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
|
|||||||
|
|
||||||
|
|
||||||
class CrewDoclingSource(BaseKnowledgeSource):
|
class CrewDoclingSource(BaseKnowledgeSource):
|
||||||
"""Default Source class for converting documents to markdown or json
|
"""Default Source class for converting documents to markdown or json.
|
||||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
|
||||||
|
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
|
||||||
|
any additional dependencies and follows the docling package as the source of truth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
if not DOCLING_AVAILABLE:
|
if not DOCLING_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The docling package is required to use CrewDoclingSource. "
|
"The docling package is required to use CrewDoclingSource. "
|
||||||
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_post_init(self, _) -> None:
|
def model_post_init(self, _: Any) -> None:
|
||||||
if self.file_path:
|
if self.file_path:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
"warning",
|
"warning",
|
||||||
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(list(new_chunks_iterable))
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add docling content asynchronously."""
|
||||||
|
if self.content is None:
|
||||||
|
return
|
||||||
|
for doc in self.content:
|
||||||
|
new_chunks_iterable = self._chunk_doc(doc)
|
||||||
|
self.chunks.extend(list(new_chunks_iterable))
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
||||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||||
return [result.document for result in conv_results_iter]
|
return [result.document for result in conv_results_iter]
|
||||||
|
|||||||
@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add CSV file content asynchronously."""
|
||||||
|
content_str = (
|
||||||
|
str(self.content) if isinstance(self.content, dict) else self.content
|
||||||
|
)
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info): # noqa: N805
|
@classmethod
|
||||||
|
def validate_file_path(
|
||||||
|
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||||
|
) -> Path | list[Path] | str | list[str] | None:
|
||||||
"""Validate that at least one of file_path or file_paths is provided."""
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
# Single check if both are None, O(1) instead of nested conditions
|
# Single check if both are None, O(1) instead of nested conditions
|
||||||
if (
|
if (
|
||||||
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
|
|
||||||
return [self.convert_to_path(path) for path in path_list]
|
return [self.convert_to_path(path) for path in path_list]
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate the paths."""
|
"""Validate the paths."""
|
||||||
for path in self.safe_file_paths:
|
for path in self.safe_file_paths:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_post_init(self, _) -> None:
|
def model_post_init(self, _: Any) -> None:
|
||||||
if self.file_path:
|
if self.file_path:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
"warning",
|
"warning",
|
||||||
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
"""Convert a path to a Path object."""
|
"""Convert a path to a Path object."""
|
||||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||||
|
|
||||||
def _import_dependencies(self):
|
def _import_dependencies(self) -> ModuleType:
|
||||||
"""Dynamically import dependencies."""
|
"""Dynamically import dependencies."""
|
||||||
try:
|
try:
|
||||||
import pandas as pd # type: ignore[import-untyped,import-not-found]
|
import pandas as pd # type: ignore[import-untyped]
|
||||||
|
|
||||||
return pd
|
return pd # type: ignore[no-any-return]
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
missing_package = str(e).split()[-1]
|
missing_package = str(e).split()[-1]
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add Excel file content asynchronously."""
|
||||||
|
content_str = ""
|
||||||
|
for value in self.content.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for sheet_value in value.values():
|
||||||
|
content_str += str(sheet_value) + "\n"
|
||||||
|
else:
|
||||||
|
content_str += str(value) + "\n"
|
||||||
|
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add JSON file content asynchronously."""
|
||||||
|
content_str = (
|
||||||
|
str(self.content) if isinstance(self.content, dict) else self.content
|
||||||
|
)
|
||||||
|
new_chunks = self._chunk_text(content_str)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
content[path] = text
|
content[path] = text
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def _import_pdfplumber(self):
|
def _import_pdfplumber(self) -> ModuleType:
|
||||||
"""Dynamically import pdfplumber."""
|
"""Dynamically import pdfplumber."""
|
||||||
try:
|
try:
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add PDF file content asynchronously."""
|
||||||
|
for text in self.content.values():
|
||||||
|
new_chunks = self._chunk_text(text)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
content: str = Field(...)
|
content: str = Field(...)
|
||||||
collection_name: str | None = Field(default=None)
|
collection_name: str | None = Field(default=None)
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _: Any) -> None:
|
||||||
"""Post-initialization method to validate content."""
|
"""Post-initialization method to validate content."""
|
||||||
self.validate_content()
|
self.validate_content()
|
||||||
|
|
||||||
def validate_content(self):
|
def validate_content(self) -> None:
|
||||||
"""Validate string content."""
|
"""Validate string content."""
|
||||||
if not isinstance(self.content, str):
|
if not isinstance(self.content, str):
|
||||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||||
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add string content asynchronously."""
|
||||||
|
new_chunks = self._chunk_text(self.content)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
self._save_documents()
|
self._save_documents()
|
||||||
|
|
||||||
|
async def aadd(self) -> None:
|
||||||
|
"""Add text file content asynchronously."""
|
||||||
|
for text in self.content.values():
|
||||||
|
new_chunks = self._chunk_text(text)
|
||||||
|
self.chunks.extend(new_chunks)
|
||||||
|
await self._asave_documents()
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> list[str]:
|
def _chunk_text(self, text: str) -> list[str]:
|
||||||
"""Utility method to split text into chunks."""
|
"""Utility method to split text into chunks."""
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
|
|||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""Search for documents in the knowledge base."""
|
"""Search for documents in the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: list[str],
|
||||||
|
limit: int = 5,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Search for documents in the knowledge base asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, documents: list[str]) -> None:
|
def save(self, documents: list[str]) -> None:
|
||||||
"""Save documents to the knowledge base."""
|
"""Save documents to the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def asave(self, documents: list[str]) -> None:
|
||||||
|
"""Save documents to the knowledge base asynchronously."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset the knowledge base."""
|
"""Reset the knowledge base."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedder: ProviderSpec
|
embedder: ProviderSpec
|
||||||
| BaseEmbeddingsProvider
|
| BaseEmbeddingsProvider[Any]
|
||||||
| type[BaseEmbeddingsProvider]
|
| type[BaseEmbeddingsProvider[Any]]
|
||||||
| None = None,
|
| None = None,
|
||||||
collection_name: str | None = None,
|
collection_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
) from e
|
) from e
|
||||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: list[str],
|
||||||
|
limit: int = 5,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""Search for documents in the knowledge base asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: List of query strings.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
metadata_filter: Optional metadata filter for the search.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not query:
|
||||||
|
raise ValueError("Query cannot be empty")
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
query_text = " ".join(query) if len(query) > 1 else query[0]
|
||||||
|
|
||||||
|
return await client.asearch(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query_text,
|
||||||
|
limit=limit,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def asave(self, documents: list[str]) -> None:
|
||||||
|
"""Save documents to the knowledge base asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: List of document strings to save.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
await client.aget_or_create_collection(collection_name=collection_name)
|
||||||
|
|
||||||
|
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||||
|
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name, documents=rag_documents
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "dimension mismatch" in str(e).lower():
|
||||||
|
Logger(verbose=True).log(
|
||||||
|
"error",
|
||||||
|
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||||
|
"across all operations with this collection."
|
||||||
|
"Try resetting the collection using `crewai reset-memories -a`"
|
||||||
|
) from e
|
||||||
|
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the knowledge base asynchronously."""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"knowledge_{self.collection_name}"
|
||||||
|
if self.collection_name
|
||||||
|
else "knowledge"
|
||||||
|
)
|
||||||
|
await client.adelete_collection(collection_name=collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from crewai.memory import (
|
from crewai.memory import (
|
||||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class ContextualMemory:
|
class ContextualMemory:
|
||||||
|
"""Aggregates and retrieves context from multiple memory sources."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stm: ShortTermMemory,
|
stm: ShortTermMemory,
|
||||||
@@ -46,9 +49,14 @@ class ContextualMemory:
|
|||||||
self.exm.task = self.task
|
self.exm.task = self.task
|
||||||
|
|
||||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||||
"""
|
"""Build contextual information for a task synchronously.
|
||||||
Automatically builds a minimal, highly relevant set of contextual information
|
|
||||||
for a given task.
|
Args:
|
||||||
|
task: The task to build context for.
|
||||||
|
context: Additional context string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string from all memory sources.
|
||||||
"""
|
"""
|
||||||
query = f"{task.description} {context}".strip()
|
query = f"{task.description} {context}".strip()
|
||||||
|
|
||||||
@@ -63,6 +71,31 @@ class ContextualMemory:
|
|||||||
]
|
]
|
||||||
return "\n".join(filter(None, context_parts))
|
return "\n".join(filter(None, context_parts))
|
||||||
|
|
||||||
|
async def abuild_context_for_task(self, task: Task, context: str) -> str:
|
||||||
|
"""Build contextual information for a task asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task to build context for.
|
||||||
|
context: Additional context string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string from all memory sources.
|
||||||
|
"""
|
||||||
|
query = f"{task.description} {context}".strip()
|
||||||
|
|
||||||
|
if query == "":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Fetch all contexts concurrently
|
||||||
|
results = await asyncio.gather(
|
||||||
|
self._afetch_ltm_context(task.description),
|
||||||
|
self._afetch_stm_context(query),
|
||||||
|
self._afetch_entity_context(query),
|
||||||
|
self._afetch_external_context(query),
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(filter(None, results))
|
||||||
|
|
||||||
def _fetch_stm_context(self, query: str) -> str:
|
def _fetch_stm_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||||
@@ -135,3 +168,87 @@ class ContextualMemory:
|
|||||||
f"- {result['content']}" for result in external_memories
|
f"- {result['content']}" for result in external_memories
|
||||||
)
|
)
|
||||||
return f"External memories:\n{formatted_memories}"
|
return f"External memories:\n{formatted_memories}"
|
||||||
|
|
||||||
|
async def _afetch_stm_context(self, query: str) -> str:
|
||||||
|
"""Fetch recent relevant insights from STM asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted insights as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.stm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
stm_results = await self.stm.asearch(query)
|
||||||
|
formatted_results = "\n".join(
|
||||||
|
[f"- {result['content']}" for result in stm_results]
|
||||||
|
)
|
||||||
|
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||||
|
|
||||||
|
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||||
|
"""Fetch historical data from LTM asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted historical data as bullet points, or None if none found.
|
||||||
|
"""
|
||||||
|
if self.ltm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||||
|
if not ltm_results:
|
||||||
|
return None
|
||||||
|
|
||||||
|
formatted_results = [
|
||||||
|
suggestion
|
||||||
|
for result in ltm_results
|
||||||
|
for suggestion in result["metadata"]["suggestions"]
|
||||||
|
]
|
||||||
|
formatted_results = list(dict.fromkeys(formatted_results))
|
||||||
|
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||||
|
|
||||||
|
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||||
|
|
||||||
|
async def _afetch_entity_context(self, query: str) -> str:
|
||||||
|
"""Fetch relevant entity information asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted entity information as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.em is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
em_results = await self.em.asearch(query)
|
||||||
|
formatted_results = "\n".join(
|
||||||
|
[f"- {result['content']}" for result in em_results]
|
||||||
|
)
|
||||||
|
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||||
|
|
||||||
|
async def _afetch_external_context(self, query: str) -> str:
|
||||||
|
"""Fetch relevant information from External Memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted information as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.exm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
external_memories = await self.exm.asearch(query)
|
||||||
|
|
||||||
|
if not external_memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_memories = "\n".join(
|
||||||
|
f"- {result['content']}" for result in external_memories
|
||||||
|
)
|
||||||
|
return f"External memories:\n{formatted_memories}"
|
||||||
|
|||||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
|||||||
|
|
||||||
_memory_provider: str | None = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
crew: Any = None,
|
||||||
|
embedder_config: Any = None,
|
||||||
|
storage: Any = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
memory_provider = None
|
memory_provider = None
|
||||||
if embedder_config and isinstance(embedder_config, dict):
|
if embedder_config and isinstance(embedder_config, dict):
|
||||||
memory_provider = embedder_config.get("provider")
|
memory_provider = embedder_config.get("provider")
|
||||||
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
|
|||||||
if embedder_config and isinstance(embedder_config, dict)
|
if embedder_config and isinstance(embedder_config, dict)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
else:
|
else:
|
||||||
storage = (
|
storage = (
|
||||||
storage
|
storage
|
||||||
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search entity memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save entity items asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||||
|
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||||
|
"""
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
|
||||||
|
items = value if isinstance(value, list) else [value]
|
||||||
|
is_batch = len(items) > 1
|
||||||
|
|
||||||
|
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
saved_count = 0
|
||||||
|
errors: list[str | None] = []
|
||||||
|
|
||||||
|
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||||
|
"""Save a single item asynchronously."""
|
||||||
|
try:
|
||||||
|
if self._memory_provider == "mem0":
|
||||||
|
data = f"""
|
||||||
|
Remember details about the following entity:
|
||||||
|
Name: {item.name}
|
||||||
|
Type: {item.type}
|
||||||
|
Entity Description: {item.description}
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
|
|
||||||
|
await super(EntityMemory, self).asave(data, item.metadata)
|
||||||
|
return True, None
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"{item.name}: {e!s}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in items:
|
||||||
|
success, error = await save_single_item(item)
|
||||||
|
if success:
|
||||||
|
saved_count += 1
|
||||||
|
else:
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
|
if is_batch:
|
||||||
|
emit_value = f"Saved {saved_count} entities"
|
||||||
|
metadata = {"entity_count": saved_count, "errors": errors}
|
||||||
|
else:
|
||||||
|
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||||
|
metadata = items[0].metadata
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=emit_value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise Exception(
|
||||||
|
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
fail_metadata = (
|
||||||
|
{"entity_count": len(items), "saved": saved_count}
|
||||||
|
if is_batch
|
||||||
|
else items[0].metadata
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
metadata=fail_metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search entity memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await super().asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="entity_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
|
|||||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
|
|
||||||
return Mem0Storage(type="external", crew=crew, config=config)
|
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def external_supported_storages() -> dict[str, Any]:
|
def external_supported_storages() -> dict[str, Any]:
|
||||||
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
|
|||||||
if provider not in supported_storages:
|
if provider not in supported_storages:
|
||||||
raise ValueError(f"Provider {provider} not supported")
|
raise ValueError(f"Provider {provider} not supported")
|
||||||
|
|
||||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
storage: Storage = supported_storages[provider](
|
||||||
|
crew, embedder_config.get("config", {})
|
||||||
|
)
|
||||||
|
return storage
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search external memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to external memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
item = ExternalMemoryItem(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
agent=self.agent.role if self.agent else None,
|
||||||
|
)
|
||||||
|
await super().asave(value=item.value, metadata=item.metadata)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search external memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await super().asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="external_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
|
|||||||
LongTermMemoryItem instances.
|
LongTermMemoryItem instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage: LTMSQLiteStorage | None = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||||
super().__init__(storage=storage)
|
super().__init__(storage=storage)
|
||||||
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
|
|||||||
metadata.update(
|
metadata.update(
|
||||||
{"agent": item.agent, "expected_output": item.expected_output}
|
{"agent": item.agent, "expected_output": item.expected_output}
|
||||||
)
|
)
|
||||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
self.storage.save(
|
||||||
task_description=item.task,
|
task_description=item.task,
|
||||||
score=metadata["quality"],
|
score=metadata["quality"],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
def search( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
latest_n: int = 3,
|
latest_n: int = 3,
|
||||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search long-term memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
results = self.storage.load(task, latest_n)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results or []
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=task,
|
||||||
|
limit=latest_n,
|
||||||
|
error=str(e),
|
||||||
|
source_type="long_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||||
|
"""Save an item to long-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: The LongTermMemoryItem to save.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
metadata = item.metadata
|
||||||
|
metadata.update(
|
||||||
|
{"agent": item.agent, "expected_output": item.expected_output}
|
||||||
|
)
|
||||||
|
await self.storage.asave(
|
||||||
|
task_description=item.task,
|
||||||
|
score=metadata["quality"],
|
||||||
|
metadata=metadata,
|
||||||
|
datetime=item.datetime,
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
error=str(e),
|
||||||
|
source_type="long_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
latest_n: int = 3,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search long-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=task,
|
||||||
|
limit=latest_n,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await self.storage.aload(task, latest_n)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=task,
|
||||||
|
results=results,
|
||||||
|
limit=latest_n,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results or []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
"""Reset long-term memory."""
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class Memory(BaseModel):
|
class Memory(BaseModel):
|
||||||
"""
|
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||||
Base class for memory, now supporting agent tags and generic metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||||
crew: Any | None = None
|
crew: Any | None = None
|
||||||
@@ -52,20 +50,72 @@ class Memory(BaseModel):
|
|||||||
value: Any,
|
value: Any,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
metadata = metadata or {}
|
"""Save a value to memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
self.storage.save(value, metadata)
|
self.storage.save(value, metadata)
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
|
await self.storage.asave(value, metadata)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
return self.storage.search(
|
"""Search memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
results: list[Any] = self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
)
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search memory for relevant entries asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
results: list[Any] = await self.storage.asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
def set_crew(self, crew: Any) -> Memory:
|
def set_crew(self, crew: Any) -> Memory:
|
||||||
|
"""Set the crew for this memory instance."""
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
|
|||||||
|
|
||||||
_memory_provider: str | None = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
crew: Any = None,
|
||||||
|
embedder_config: Any = None,
|
||||||
|
storage: Any = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
memory_provider = None
|
memory_provider = None
|
||||||
if embedder_config and isinstance(embedder_config, dict):
|
if embedder_config and isinstance(embedder_config, dict):
|
||||||
memory_provider = embedder_config.get("provider")
|
memory_provider = embedder_config.get("provider")
|
||||||
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
|
|||||||
if embedder_config and isinstance(embedder_config, dict)
|
if embedder_config and isinstance(embedder_config, dict)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
else:
|
else:
|
||||||
storage = (
|
storage = (
|
||||||
storage
|
storage
|
||||||
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search short-term memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
|
|||||||
try:
|
try:
|
||||||
results = self.storage.search(
|
results = self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return list(results)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="short_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to short-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
item = ShortTermMemoryItem(
|
||||||
|
data=value,
|
||||||
|
metadata=metadata,
|
||||||
|
agent=self.agent.role if self.agent else None,
|
||||||
|
)
|
||||||
|
if self._memory_provider == "mem0":
|
||||||
|
item.data = (
|
||||||
|
f"Remember the following insights from Agent run: {item.data}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await super().asave(value=item.data, metadata=item.metadata)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search short-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await self.storage.asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(results)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -3,29 +3,30 @@ from pathlib import Path
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
class LTMSQLiteStorage:
|
class LTMSQLiteStorage:
|
||||||
"""
|
"""SQLite storage class for long-term memory data."""
|
||||||
An updated SQLite storage class for LTM data storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db_path: str | None = None) -> None:
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
|
"""Initialize the SQLite storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Optional path to the database file.
|
||||||
|
"""
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
|
||||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self._printer: Printer = Printer()
|
self._printer: Printer = Printer()
|
||||||
# Ensure parent directory exists
|
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
self._initialize_db()
|
self._initialize_db()
|
||||||
|
|
||||||
def _initialize_db(self):
|
def _initialize_db(self) -> None:
|
||||||
"""
|
"""Initialize the SQLite database and create LTM table."""
|
||||||
Initializes the SQLite database and creates LTM table
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def reset(
|
def reset(self) -> None:
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""Resets the LTM table with error handling."""
|
"""Resets the LTM table with error handling."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
|
|||||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
task_description: str,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
datetime: str,
|
||||||
|
score: int | float,
|
||||||
|
) -> None:
|
||||||
|
"""Save data to the LTM table asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_description: Description of the task.
|
||||||
|
metadata: Metadata associated with the memory.
|
||||||
|
datetime: Timestamp of the memory.
|
||||||
|
score: Quality score of the memory.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(task_description, json.dumps(metadata), datetime, score),
|
||||||
|
)
|
||||||
|
await conn.commit()
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aload(
|
||||||
|
self, task_description: str, latest_n: int
|
||||||
|
) -> list[dict[str, Any]] | None:
|
||||||
|
"""Query the LTM table by task description asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_description: Description of the task to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries or None if error occurs.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
f"""
|
||||||
|
SELECT metadata, datetime, score
|
||||||
|
FROM long_term_memories
|
||||||
|
WHERE task_description = ?
|
||||||
|
ORDER BY datetime DESC, score ASC
|
||||||
|
LIMIT {latest_n}
|
||||||
|
""", # nosec # noqa: S608
|
||||||
|
(task_description,),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
if rows:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"metadata": json.loads(row[0]),
|
||||||
|
"datetime": row[1],
|
||||||
|
"score": row[2],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the LTM table asynchronously."""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
await conn.execute("DELETE FROM long_term_memories")
|
||||||
|
await conn.commit()
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
return f"{base_path}/{file_name}"
|
return f"{base_path}/{file_name}"
|
||||||
|
|
||||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
|
"""Save a value to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Metadata to associate with the value.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
collection_name = (
|
collection_name = (
|
||||||
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
|
"""Save a value to storage asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"memory_{self.type}_{self.agents}"
|
||||||
|
if self.agents
|
||||||
|
else f"memory_{self.type}"
|
||||||
|
)
|
||||||
|
await client.aget_or_create_collection(collection_name=collection_name)
|
||||||
|
|
||||||
|
document: BaseRecord = {"content": value}
|
||||||
|
if metadata:
|
||||||
|
document["metadata"] = metadata
|
||||||
|
|
||||||
|
batch_size = None
|
||||||
|
if (
|
||||||
|
self.embedder_config
|
||||||
|
and isinstance(self.embedder_config, dict)
|
||||||
|
and "config" in self.embedder_config
|
||||||
|
):
|
||||||
|
nested_config = self.embedder_config["config"]
|
||||||
|
if isinstance(nested_config, dict):
|
||||||
|
batch_size = nested_config.get("batch_size")
|
||||||
|
|
||||||
|
if batch_size is not None:
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name,
|
||||||
|
documents=[document],
|
||||||
|
batch_size=cast(int, batch_size),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name, documents=[document]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
filter: dict[str, Any] | None = None,
|
filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
|
"""Search for matching entries in storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
filter: Optional metadata filter.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
collection_name = (
|
collection_name = (
|
||||||
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search for matching entries in storage asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
filter: Optional metadata filter.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"memory_{self.type}_{self.agents}"
|
||||||
|
if self.agents
|
||||||
|
else f"memory_{self.type}"
|
||||||
|
)
|
||||||
|
return await client.asearch(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
metadata_filter=filter,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|||||||
@@ -497,6 +497,107 @@ class Task(BaseModel):
|
|||||||
result = self._execute_core(agent, context, tools)
|
result = self._execute_core(agent, context, tools)
|
||||||
future.set_result(result)
|
future.set_result(result)
|
||||||
|
|
||||||
|
async def aexecute_sync(
|
||||||
|
self,
|
||||||
|
agent: BaseAgent | None = None,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Execute the task asynchronously using native async/await."""
|
||||||
|
return await self._aexecute_core(agent, context, tools)
|
||||||
|
|
||||||
|
async def _aexecute_core(
|
||||||
|
self,
|
||||||
|
agent: BaseAgent | None,
|
||||||
|
context: str | None,
|
||||||
|
tools: list[Any] | None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Run the core execution logic of the task asynchronously."""
|
||||||
|
try:
|
||||||
|
agent = agent or self.agent
|
||||||
|
self.agent = agent
|
||||||
|
if not agent:
|
||||||
|
raise Exception(
|
||||||
|
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
self.prompt_context = context
|
||||||
|
tools = tools or self.tools or []
|
||||||
|
|
||||||
|
self.processed_by_agents.add(agent.role)
|
||||||
|
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||||
|
result = await agent.aexecute_task(
|
||||||
|
task=self,
|
||||||
|
context=context,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._guardrails and not self._guardrail:
|
||||||
|
pydantic_output, json_output = self._export_output(result)
|
||||||
|
else:
|
||||||
|
pydantic_output, json_output = None, None
|
||||||
|
|
||||||
|
task_output = TaskOutput(
|
||||||
|
name=self.name or self.description,
|
||||||
|
description=self.description,
|
||||||
|
expected_output=self.expected_output,
|
||||||
|
raw=result,
|
||||||
|
pydantic=pydantic_output,
|
||||||
|
json_dict=json_output,
|
||||||
|
agent=agent.role,
|
||||||
|
output_format=self._get_output_format(),
|
||||||
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._guardrails:
|
||||||
|
for idx, guardrail in enumerate(self._guardrails):
|
||||||
|
task_output = await self._ainvoke_guardrail_function(
|
||||||
|
task_output=task_output,
|
||||||
|
agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
guardrail=guardrail,
|
||||||
|
guardrail_index=idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._guardrail:
|
||||||
|
task_output = await self._ainvoke_guardrail_function(
|
||||||
|
task_output=task_output,
|
||||||
|
agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
guardrail=self._guardrail,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output = task_output
|
||||||
|
self.end_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
if self.callback:
|
||||||
|
self.callback(self.output)
|
||||||
|
|
||||||
|
crew = self.agent.crew # type: ignore[union-attr]
|
||||||
|
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||||
|
crew.task_callback(self.output)
|
||||||
|
|
||||||
|
if self.output_file:
|
||||||
|
content = (
|
||||||
|
json_output
|
||||||
|
if json_output
|
||||||
|
else (
|
||||||
|
pydantic_output.model_dump_json() if pydantic_output else result
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._save_file(content)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||||
|
)
|
||||||
|
return task_output
|
||||||
|
except Exception as e:
|
||||||
|
self.end_time = datetime.datetime.now()
|
||||||
|
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||||
|
raise e # Re-raise the exception after emitting the event
|
||||||
|
|
||||||
def _execute_core(
|
def _execute_core(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent | None,
|
agent: BaseAgent | None,
|
||||||
@@ -539,7 +640,7 @@ class Task(BaseModel):
|
|||||||
json_dict=json_output,
|
json_dict=json_output,
|
||||||
agent=agent.role,
|
agent=agent.role,
|
||||||
output_format=self._get_output_format(),
|
output_format=self._get_output_format(),
|
||||||
messages=agent.last_messages,
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._guardrails:
|
if self._guardrails:
|
||||||
@@ -950,7 +1051,103 @@ Follow these guidelines:
|
|||||||
json_dict=json_output,
|
json_dict=json_output,
|
||||||
agent=agent.role,
|
agent=agent.role,
|
||||||
output_format=self._get_output_format(),
|
output_format=self._get_output_format(),
|
||||||
messages=agent.last_messages,
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
async def _ainvoke_guardrail_function(
|
||||||
|
self,
|
||||||
|
task_output: TaskOutput,
|
||||||
|
agent: BaseAgent,
|
||||||
|
tools: list[BaseTool],
|
||||||
|
guardrail: GuardrailCallable | None,
|
||||||
|
guardrail_index: int | None = None,
|
||||||
|
) -> TaskOutput:
|
||||||
|
"""Invoke the guardrail function asynchronously."""
|
||||||
|
if not guardrail:
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
if guardrail_index is not None:
|
||||||
|
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
|
||||||
|
else:
|
||||||
|
current_retry_count = self.retry_count
|
||||||
|
|
||||||
|
max_attempts = self.guardrail_max_retries + 1
|
||||||
|
|
||||||
|
for attempt in range(max_attempts):
|
||||||
|
guardrail_result = process_guardrail(
|
||||||
|
output=task_output,
|
||||||
|
guardrail=guardrail,
|
||||||
|
retry_count=current_retry_count,
|
||||||
|
event_source=self,
|
||||||
|
from_task=self,
|
||||||
|
from_agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
if guardrail_result.success:
|
||||||
|
if guardrail_result.result is None:
|
||||||
|
raise Exception(
|
||||||
|
"Task guardrail returned None as result. This is not allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(guardrail_result.result, str):
|
||||||
|
task_output.raw = guardrail_result.result
|
||||||
|
pydantic_output, json_output = self._export_output(
|
||||||
|
guardrail_result.result
|
||||||
|
)
|
||||||
|
task_output.pydantic = pydantic_output
|
||||||
|
task_output.json_dict = json_output
|
||||||
|
elif isinstance(guardrail_result.result, TaskOutput):
|
||||||
|
task_output = guardrail_result.result
|
||||||
|
|
||||||
|
return task_output
|
||||||
|
|
||||||
|
if attempt >= self.guardrail_max_retries:
|
||||||
|
guardrail_name = (
|
||||||
|
f"guardrail {guardrail_index}"
|
||||||
|
if guardrail_index is not None
|
||||||
|
else "guardrail"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
|
||||||
|
f"Last error: {guardrail_result.error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if guardrail_index is not None:
|
||||||
|
current_retry_count += 1
|
||||||
|
self._guardrail_retry_counts[guardrail_index] = current_retry_count
|
||||||
|
else:
|
||||||
|
self.retry_count += 1
|
||||||
|
current_retry_count = self.retry_count
|
||||||
|
|
||||||
|
context = self.i18n.errors("validation_error").format(
|
||||||
|
guardrail_result_error=guardrail_result.error,
|
||||||
|
task_output=task_output.raw,
|
||||||
|
)
|
||||||
|
printer = Printer()
|
||||||
|
printer.print(
|
||||||
|
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
|
||||||
|
color="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await agent.aexecute_task(
|
||||||
|
task=self,
|
||||||
|
context=context,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
pydantic_output, json_output = self._export_output(result)
|
||||||
|
task_output = TaskOutput(
|
||||||
|
name=self.name or self.description,
|
||||||
|
description=self.description,
|
||||||
|
expected_output=self.expected_output,
|
||||||
|
raw=result,
|
||||||
|
pydantic=pydantic_output,
|
||||||
|
json_dict=json_output,
|
||||||
|
agent=agent.role,
|
||||||
|
output_format=self._get_output_format(),
|
||||||
|
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
return task_output
|
return task_output
|
||||||
|
|||||||
@@ -242,17 +242,17 @@ def get_llm_response(
|
|||||||
"""Call the LLM and return the response, handling any invalid responses.
|
"""Call the LLM and return the response, handling any invalid responses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm: The LLM instance to call
|
llm: The LLM instance to call.
|
||||||
messages: The messages to send to the LLM
|
messages: The messages to send to the LLM.
|
||||||
callbacks: List of callbacks for the LLM call
|
callbacks: List of callbacks for the LLM call.
|
||||||
printer: Printer instance for output
|
printer: Printer instance for output.
|
||||||
from_task: Optional task context for the LLM call
|
from_task: Optional task context for the LLM call.
|
||||||
from_agent: Optional agent context for the LLM call
|
from_agent: Optional agent context for the LLM call.
|
||||||
response_model: Optional Pydantic model for structured outputs
|
response_model: Optional Pydantic model for structured outputs.
|
||||||
executor_context: Optional executor context for hook invocation
|
executor_context: Optional executor context for hook invocation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response from the LLM as a string
|
The response from the LLM as a string.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If an error occurs.
|
Exception: If an error occurs.
|
||||||
@@ -284,6 +284,60 @@ def get_llm_response(
|
|||||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||||
|
|
||||||
|
|
||||||
|
async def aget_llm_response(
|
||||||
|
llm: LLM | BaseLLM,
|
||||||
|
messages: list[LLMMessage],
|
||||||
|
callbacks: list[TokenCalcHandler],
|
||||||
|
printer: Printer,
|
||||||
|
from_task: Task | None = None,
|
||||||
|
from_agent: Agent | LiteAgent | None = None,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
executor_context: CrewAgentExecutor | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Call the LLM asynchronously and return the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: The LLM instance to call.
|
||||||
|
messages: The messages to send to the LLM.
|
||||||
|
callbacks: List of callbacks for the LLM call.
|
||||||
|
printer: Printer instance for output.
|
||||||
|
from_task: Optional task context for the LLM call.
|
||||||
|
from_agent: Optional agent context for the LLM call.
|
||||||
|
response_model: Optional Pydantic model for structured outputs.
|
||||||
|
executor_context: Optional executor context for hook invocation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response from the LLM as a string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If an error occurs.
|
||||||
|
ValueError: If the response is None or empty.
|
||||||
|
"""
|
||||||
|
if executor_context is not None:
|
||||||
|
if not _setup_before_llm_call_hooks(executor_context, printer):
|
||||||
|
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||||
|
messages = executor_context.messages
|
||||||
|
|
||||||
|
try:
|
||||||
|
answer = await llm.acall(
|
||||||
|
messages,
|
||||||
|
callbacks=callbacks,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent, # type: ignore[arg-type]
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
if not answer:
|
||||||
|
printer.print(
|
||||||
|
content="Received None or empty response from LLM call.",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
|
||||||
|
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(
|
def process_llm_response(
|
||||||
answer: str, use_stop_words: bool
|
answer: str, use_stop_words: bool
|
||||||
) -> AgentAction | AgentFinish:
|
) -> AgentAction | AgentFinish:
|
||||||
|
|||||||
@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
|
|||||||
# Dummy implementation for MCP tools
|
# Dummy implementation for MCP tools
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
# Dummy async implementation
|
||||||
|
return "Task executed"
|
||||||
|
|
||||||
|
|
||||||
def test_base_agent_adapter_initialization():
|
def test_base_agent_adapter_initialization():
|
||||||
"""Test initialization of the concrete agent adapter."""
|
"""Test initialization of the concrete agent adapter."""
|
||||||
|
|||||||
@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
|
|||||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def aexecute_task(
|
||||||
|
self,
|
||||||
|
task: Any,
|
||||||
|
context: str | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
def get_output_converter(
|
def get_output_converter(
|
||||||
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
||||||
): ...
|
): ...
|
||||||
|
|||||||
345
lib/crewai/tests/agents/test_async_agent_executor.py
Normal file
345
lib/crewai/tests/agents/test_async_agent_executor.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""Tests for async agent executor functionality."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
|
from crewai.agents.parser import AgentAction, AgentFinish
|
||||||
|
from crewai.tools.tool_types import ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm() -> MagicMock:
|
||||||
|
"""Create a mock LLM for testing."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.supports_stop_words.return_value = True
|
||||||
|
llm.stop = []
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent() -> MagicMock:
|
||||||
|
"""Create a mock agent for testing."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.role = "Test Agent"
|
||||||
|
agent.key = "test_agent_key"
|
||||||
|
agent.verbose = False
|
||||||
|
agent.id = "test_agent_id"
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_task() -> MagicMock:
|
||||||
|
"""Create a mock task for testing."""
|
||||||
|
task = MagicMock()
|
||||||
|
task.description = "Test task description"
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_crew() -> MagicMock:
|
||||||
|
"""Create a mock crew for testing."""
|
||||||
|
crew = MagicMock()
|
||||||
|
crew.verbose = False
|
||||||
|
crew._train = False
|
||||||
|
return crew
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tools_handler() -> MagicMock:
|
||||||
|
"""Create a mock tools handler."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor(
|
||||||
|
mock_llm: MagicMock,
|
||||||
|
mock_agent: MagicMock,
|
||||||
|
mock_task: MagicMock,
|
||||||
|
mock_crew: MagicMock,
|
||||||
|
mock_tools_handler: MagicMock,
|
||||||
|
) -> CrewAgentExecutor:
|
||||||
|
"""Create a CrewAgentExecutor instance for testing."""
|
||||||
|
return CrewAgentExecutor(
|
||||||
|
llm=mock_llm,
|
||||||
|
task=mock_task,
|
||||||
|
crew=mock_crew,
|
||||||
|
agent=mock_agent,
|
||||||
|
prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
|
||||||
|
max_iter=5,
|
||||||
|
tools=[],
|
||||||
|
tools_names="",
|
||||||
|
stop_words=["Observation:"],
|
||||||
|
tools_description="",
|
||||||
|
tools_handler=mock_tools_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncAgentExecutor:
|
||||||
|
"""Tests for async agent executor methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_returns_output(self, executor: CrewAgentExecutor) -> None:
|
||||||
|
"""Test that ainvoke returns the expected output."""
|
||||||
|
expected_output = "Final answer from agent"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
executor,
|
||||||
|
"_ainvoke_loop",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=AgentFinish(
|
||||||
|
thought="Done", output=expected_output, text="Final Answer: Done"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with patch.object(executor, "_show_start_logs"):
|
||||||
|
with patch.object(executor, "_create_short_term_memory"):
|
||||||
|
with patch.object(executor, "_create_long_term_memory"):
|
||||||
|
with patch.object(executor, "_create_external_memory"):
|
||||||
|
result = await executor.ainvoke(
|
||||||
|
{
|
||||||
|
"input": "test input",
|
||||||
|
"tool_names": "",
|
||||||
|
"tools": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"output": expected_output}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_loop_calls_aget_llm_response(
|
||||||
|
self, executor: CrewAgentExecutor
|
||||||
|
) -> None:
|
||||||
|
"""Test that _ainvoke_loop calls aget_llm_response."""
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aget_llm_response",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value="Thought: I know the answer\nFinal Answer: Test result",
|
||||||
|
) as mock_aget_llm:
|
||||||
|
with patch.object(executor, "_show_logs"):
|
||||||
|
result = await executor._ainvoke_loop()
|
||||||
|
|
||||||
|
mock_aget_llm.assert_called_once()
|
||||||
|
assert isinstance(result, AgentFinish)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_loop_handles_tool_execution(
|
||||||
|
self,
|
||||||
|
executor: CrewAgentExecutor,
|
||||||
|
) -> None:
|
||||||
|
"""Test that _ainvoke_loop handles tool execution asynchronously."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def mock_llm_response(*args: Any, **kwargs: Any) -> str:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return (
|
||||||
|
"Thought: I need to use a tool\n"
|
||||||
|
"Action: test_tool\n"
|
||||||
|
'Action Input: {"arg": "value"}'
|
||||||
|
)
|
||||||
|
return "Thought: I have the answer\nFinal Answer: Tool result processed"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aget_llm_response",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=mock_llm_response,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=ToolResult(result="Tool executed", result_as_answer=False),
|
||||||
|
) as mock_tool_exec:
|
||||||
|
with patch.object(executor, "_show_logs"):
|
||||||
|
with patch.object(executor, "_handle_agent_action") as mock_handle:
|
||||||
|
mock_handle.return_value = AgentAction(
|
||||||
|
text="Tool result",
|
||||||
|
tool="test_tool",
|
||||||
|
tool_input='{"arg": "value"}',
|
||||||
|
thought="Used tool",
|
||||||
|
result="Tool executed",
|
||||||
|
)
|
||||||
|
result = await executor._ainvoke_loop()
|
||||||
|
|
||||||
|
assert mock_tool_exec.called
|
||||||
|
assert isinstance(result, AgentFinish)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_loop_respects_max_iterations(
|
||||||
|
self, executor: CrewAgentExecutor
|
||||||
|
) -> None:
|
||||||
|
"""Test that _ainvoke_loop respects max iterations."""
|
||||||
|
executor.max_iter = 2
|
||||||
|
|
||||||
|
async def always_return_action(*args: Any, **kwargs: Any) -> str:
|
||||||
|
return (
|
||||||
|
"Thought: I need to think more\n"
|
||||||
|
"Action: some_tool\n"
|
||||||
|
"Action Input: {}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aget_llm_response",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=always_return_action,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=ToolResult(result="Tool result", result_as_answer=False),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.handle_max_iterations_exceeded",
|
||||||
|
return_value=AgentFinish(
|
||||||
|
thought="Max iterations",
|
||||||
|
output="Forced answer",
|
||||||
|
text="Max iterations reached",
|
||||||
|
),
|
||||||
|
) as mock_max_iter:
|
||||||
|
with patch.object(executor, "_show_logs"):
|
||||||
|
with patch.object(executor, "_handle_agent_action") as mock_ha:
|
||||||
|
mock_ha.return_value = AgentAction(
|
||||||
|
text="Action",
|
||||||
|
tool="some_tool",
|
||||||
|
tool_input="{}",
|
||||||
|
thought="Thinking",
|
||||||
|
)
|
||||||
|
result = await executor._ainvoke_loop()
|
||||||
|
|
||||||
|
mock_max_iter.assert_called_once()
|
||||||
|
assert isinstance(result, AgentFinish)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ainvoke_handles_exceptions(
|
||||||
|
self, executor: CrewAgentExecutor
|
||||||
|
) -> None:
|
||||||
|
"""Test that ainvoke properly propagates exceptions."""
|
||||||
|
with patch.object(executor, "_show_start_logs"):
|
||||||
|
with patch.object(
|
||||||
|
executor,
|
||||||
|
"_ainvoke_loop",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ValueError("Test error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
|
await executor.ainvoke(
|
||||||
|
{"input": "test", "tool_names": "", "tools": ""}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_ainvoke_calls(
|
||||||
|
self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
|
||||||
|
mock_crew: MagicMock, mock_tools_handler: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that multiple ainvoke calls can run concurrently."""
|
||||||
|
|
||||||
|
async def create_and_run_executor(executor_id: int) -> dict[str, Any]:
|
||||||
|
executor = CrewAgentExecutor(
|
||||||
|
llm=mock_llm,
|
||||||
|
task=mock_task,
|
||||||
|
crew=mock_crew,
|
||||||
|
agent=mock_agent,
|
||||||
|
prompt={"prompt": "Test {input} {tool_names} {tools}"},
|
||||||
|
max_iter=5,
|
||||||
|
tools=[],
|
||||||
|
tools_names="",
|
||||||
|
stop_words=["Observation:"],
|
||||||
|
tools_description="",
|
||||||
|
tools_handler=mock_tools_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delayed_response(*args: Any, **kwargs: Any) -> str:
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
return f"Thought: Done\nFinal Answer: Result from executor {executor_id}"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai.agents.crew_agent_executor.aget_llm_response",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=delayed_response,
|
||||||
|
):
|
||||||
|
with patch.object(executor, "_show_start_logs"):
|
||||||
|
with patch.object(executor, "_show_logs"):
|
||||||
|
with patch.object(executor, "_create_short_term_memory"):
|
||||||
|
with patch.object(executor, "_create_long_term_memory"):
|
||||||
|
with patch.object(executor, "_create_external_memory"):
|
||||||
|
return await executor.ainvoke(
|
||||||
|
{
|
||||||
|
"input": f"test {executor_id}",
|
||||||
|
"tool_names": "",
|
||||||
|
"tools": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
results = await asyncio.gather(
|
||||||
|
create_and_run_executor(1),
|
||||||
|
create_and_run_executor(2),
|
||||||
|
create_and_run_executor(3),
|
||||||
|
)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all("output" in r for r in results)
|
||||||
|
assert elapsed < 0.15, f"Expected concurrent execution, took {elapsed}s"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncLLMResponseHelper:
|
||||||
|
"""Tests for aget_llm_response helper function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aget_llm_response_calls_acall(self) -> None:
|
||||||
|
"""Test that aget_llm_response calls llm.acall."""
|
||||||
|
from crewai.utilities.agent_utils import aget_llm_response
|
||||||
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.acall = AsyncMock(return_value="LLM response")
|
||||||
|
|
||||||
|
result = await aget_llm_response(
|
||||||
|
llm=mock_llm,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
callbacks=[],
|
||||||
|
printer=Printer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_llm.acall.assert_called_once()
|
||||||
|
assert result == "LLM response"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aget_llm_response_raises_on_empty_response(self) -> None:
|
||||||
|
"""Test that aget_llm_response raises ValueError on empty response."""
|
||||||
|
from crewai.utilities.agent_utils import aget_llm_response
|
||||||
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.acall = AsyncMock(return_value="")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid response from LLM call"):
|
||||||
|
await aget_llm_response(
|
||||||
|
llm=mock_llm,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
callbacks=[],
|
||||||
|
printer=Printer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aget_llm_response_propagates_exceptions(self) -> None:
|
||||||
|
"""Test that aget_llm_response propagates LLM exceptions."""
|
||||||
|
from crewai.utilities.agent_utils import aget_llm_response
|
||||||
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.acall = AsyncMock(side_effect=RuntimeError("LLM error"))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="LLM error"):
|
||||||
|
await aget_llm_response(
|
||||||
|
llm=mock_llm,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
callbacks=[],
|
||||||
|
printer=Printer(),
|
||||||
|
)
|
||||||
384
lib/crewai/tests/crew/test_async_crew.py
Normal file
384
lib/crewai/tests/crew/test_async_crew.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
"""Tests for async crew execution."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.crews.crew_output import CrewOutput
|
||||||
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_agent() -> Agent:
|
||||||
|
"""Create a test agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_task(test_agent: Agent) -> Task:
|
||||||
|
"""Create a test task."""
|
||||||
|
return Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_crew(test_agent: Agent, test_task: Task) -> Crew:
|
||||||
|
"""Create a test crew."""
|
||||||
|
return Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[test_task],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCrewKickoff:
|
||||||
|
"""Tests for async crew kickoff methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_basic(
|
||||||
|
self, mock_execute: AsyncMock, test_crew: Crew
|
||||||
|
) -> None:
|
||||||
|
"""Test basic async crew kickoff."""
|
||||||
|
mock_output = TaskOutput(
|
||||||
|
description="Test task description",
|
||||||
|
raw="Task result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.return_value = mock_output
|
||||||
|
|
||||||
|
result = await test_crew.akickoff()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, CrewOutput)
|
||||||
|
assert result.raw == "Task result"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_with_inputs(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async crew kickoff with inputs."""
|
||||||
|
task = Task(
|
||||||
|
description="Test task for {topic}",
|
||||||
|
expected_output="Expected output for {topic}",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output = TaskOutput(
|
||||||
|
description="Test task for AI",
|
||||||
|
raw="Task result about AI",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.return_value = mock_output
|
||||||
|
|
||||||
|
result = await crew.akickoff(inputs={"topic": "AI"})
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, CrewOutput)
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_multiple_tasks(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async crew kickoff with multiple tasks."""
|
||||||
|
task1 = Task(
|
||||||
|
description="First task",
|
||||||
|
expected_output="First output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
task2 = Task(
|
||||||
|
description="Second task",
|
||||||
|
expected_output="Second output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task1, task2],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output1 = TaskOutput(
|
||||||
|
description="First task",
|
||||||
|
raw="First result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_output2 = TaskOutput(
|
||||||
|
description="Second task",
|
||||||
|
raw="Second result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||||
|
|
||||||
|
result = await crew.akickoff()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, CrewOutput)
|
||||||
|
assert result.raw == "Second result"
|
||||||
|
assert mock_execute.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_handles_exception(
|
||||||
|
self, mock_execute: AsyncMock, test_crew: Crew
|
||||||
|
) -> None:
|
||||||
|
"""Test that async kickoff handles exceptions properly."""
|
||||||
|
mock_execute.side_effect = RuntimeError("Test error")
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
await test_crew.akickoff()
|
||||||
|
|
||||||
|
assert "Test error" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_calls_before_callbacks(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async kickoff calls before_kickoff_callbacks."""
|
||||||
|
callback_called = False
|
||||||
|
|
||||||
|
def before_callback(inputs: dict | None) -> dict:
|
||||||
|
nonlocal callback_called
|
||||||
|
callback_called = True
|
||||||
|
return inputs or {}
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=False,
|
||||||
|
before_kickoff_callbacks=[before_callback],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output = TaskOutput(
|
||||||
|
description="Test task",
|
||||||
|
raw="Task result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.return_value = mock_output
|
||||||
|
|
||||||
|
await crew.akickoff()
|
||||||
|
|
||||||
|
assert callback_called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_calls_after_callbacks(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async kickoff calls after_kickoff_callbacks."""
|
||||||
|
callback_called = False
|
||||||
|
|
||||||
|
def after_callback(result: CrewOutput) -> CrewOutput:
|
||||||
|
nonlocal callback_called
|
||||||
|
callback_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=False,
|
||||||
|
after_kickoff_callbacks=[after_callback],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output = TaskOutput(
|
||||||
|
description="Test task",
|
||||||
|
raw="Task result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.return_value = mock_output
|
||||||
|
|
||||||
|
await crew.akickoff()
|
||||||
|
|
||||||
|
assert callback_called
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCrewKickoffForEach:
|
||||||
|
"""Tests for async crew kickoff_for_each methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_for_each_basic(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test basic async kickoff_for_each."""
|
||||||
|
task = Task(
|
||||||
|
description="Test task for {topic}",
|
||||||
|
expected_output="Expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output1 = TaskOutput(
|
||||||
|
description="Test task for AI",
|
||||||
|
raw="Result about AI",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_output2 = TaskOutput(
|
||||||
|
description="Test task for ML",
|
||||||
|
raw="Result about ML",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||||
|
|
||||||
|
inputs = [{"topic": "AI"}, {"topic": "ML"}]
|
||||||
|
results = await crew.akickoff_for_each(inputs)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert all(isinstance(r, CrewOutput) for r in results)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_akickoff_for_each_concurrent(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async kickoff_for_each runs concurrently."""
|
||||||
|
task = Task(
|
||||||
|
description="Test task for {topic}",
|
||||||
|
expected_output="Expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output = TaskOutput(
|
||||||
|
description="Test task",
|
||||||
|
raw="Result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.return_value = mock_output
|
||||||
|
|
||||||
|
inputs = [{"topic": f"topic_{i}"} for i in range(3)]
|
||||||
|
results = await crew.akickoff_for_each(inputs)
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskExecution:
|
||||||
|
"""Tests for async task execution within crew."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_tasks_sequential(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async sequential task execution."""
|
||||||
|
task1 = Task(
|
||||||
|
description="First task",
|
||||||
|
expected_output="First output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
task2 = Task(
|
||||||
|
description="Second task",
|
||||||
|
expected_output="Second output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task1, task2],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output1 = TaskOutput(
|
||||||
|
description="First task",
|
||||||
|
raw="First result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_output2 = TaskOutput(
|
||||||
|
description="Second task",
|
||||||
|
raw="Second result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||||
|
|
||||||
|
result = await crew._aexecute_tasks(crew.tasks)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.raw == "Second result"
|
||||||
|
assert len(result.tasks_output) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.task.Task.aexecute_sync", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_tasks_with_async_task(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with async_execution task flag."""
|
||||||
|
task1 = Task(
|
||||||
|
description="Async task",
|
||||||
|
expected_output="Async output",
|
||||||
|
agent=test_agent,
|
||||||
|
async_execution=True,
|
||||||
|
)
|
||||||
|
task2 = Task(
|
||||||
|
description="Sync task",
|
||||||
|
expected_output="Sync output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
crew = Crew(
|
||||||
|
agents=[test_agent],
|
||||||
|
tasks=[task1, task2],
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_output1 = TaskOutput(
|
||||||
|
description="Async task",
|
||||||
|
raw="Async result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_output2 = TaskOutput(
|
||||||
|
description="Sync task",
|
||||||
|
raw="Sync result",
|
||||||
|
agent="Test Agent",
|
||||||
|
)
|
||||||
|
mock_execute.side_effect = [mock_output1, mock_output2]
|
||||||
|
|
||||||
|
result = await crew._aexecute_tasks(crew.tasks)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert mock_execute.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncProcessAsyncTasks:
|
||||||
|
"""Tests for _aprocess_async_tasks method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aprocess_async_tasks_empty(self, test_crew: Crew) -> None:
|
||||||
|
"""Test processing empty list of async tasks."""
|
||||||
|
result = await test_crew._aprocess_async_tasks([])
|
||||||
|
assert result == []
|
||||||
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
212
lib/crewai/tests/knowledge/test_async_knowledge.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Tests for async knowledge operations."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
|
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||||
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncKnowledgeStorage:
|
||||||
|
"""Tests for async KnowledgeStorage operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_returns_results(self):
|
||||||
|
"""Test that asearch returns search results."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.asearch = AsyncMock(
|
||||||
|
return_value=[{"content": "test result", "score": 0.9}]
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
results = await storage.asearch(["test query"])
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["content"] == "test result"
|
||||||
|
mock_client.asearch.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_empty_query_raises_error(self):
|
||||||
|
"""Test that asearch handles empty query."""
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
|
||||||
|
# Empty query should not raise but return empty results due to error handling
|
||||||
|
results = await storage.asearch([])
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_calls_client_methods(self):
|
||||||
|
"""Test that asave calls the correct client methods."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.aget_or_create_collection = AsyncMock()
|
||||||
|
mock_client.aadd_documents = AsyncMock()
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
await storage.asave(["document 1", "document 2"])
|
||||||
|
|
||||||
|
mock_client.aget_or_create_collection.assert_called_once_with(
|
||||||
|
collection_name="knowledge_test_collection"
|
||||||
|
)
|
||||||
|
mock_client.aadd_documents.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_areset_calls_client_delete(self):
|
||||||
|
"""Test that areset calls delete_collection on the client."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.adelete_collection = AsyncMock()
|
||||||
|
|
||||||
|
storage = KnowledgeStorage(collection_name="test_collection")
|
||||||
|
storage._client = mock_client
|
||||||
|
|
||||||
|
await storage.areset()
|
||||||
|
|
||||||
|
mock_client.adelete_collection.assert_called_once_with(
|
||||||
|
collection_name="knowledge_test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncKnowledge:
|
||||||
|
"""Tests for async Knowledge operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aquery_calls_storage_asearch(self):
|
||||||
|
"""Test that aquery calls storage.asearch."""
|
||||||
|
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||||
|
mock_storage.asearch = AsyncMock(
|
||||||
|
return_value=[{"content": "result", "score": 0.8}]
|
||||||
|
)
|
||||||
|
|
||||||
|
knowledge = Knowledge(
|
||||||
|
collection_name="test",
|
||||||
|
sources=[],
|
||||||
|
storage=mock_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await knowledge.aquery(["test query"])
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
mock_storage.asearch.assert_called_once_with(
|
||||||
|
["test query"],
|
||||||
|
limit=5,
|
||||||
|
score_threshold=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aquery_raises_when_storage_not_initialized(self):
|
||||||
|
"""Test that aquery raises ValueError when storage is None."""
|
||||||
|
knowledge = Knowledge(
|
||||||
|
collection_name="test",
|
||||||
|
sources=[],
|
||||||
|
storage=MagicMock(spec=KnowledgeStorage),
|
||||||
|
)
|
||||||
|
knowledge.storage = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Storage is not initialized"):
|
||||||
|
await knowledge.aquery(["test query"])
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aadd_sources_calls_source_aadd(self):
|
||||||
|
"""Test that aadd_sources calls aadd on each source."""
|
||||||
|
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||||
|
mock_source = MagicMock()
|
||||||
|
mock_source.aadd = AsyncMock()
|
||||||
|
|
||||||
|
knowledge = Knowledge(
|
||||||
|
collection_name="test",
|
||||||
|
sources=[mock_source],
|
||||||
|
storage=mock_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
await knowledge.aadd_sources()
|
||||||
|
|
||||||
|
mock_source.aadd.assert_called_once()
|
||||||
|
assert mock_source.storage == mock_storage
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_areset_calls_storage_areset(self):
|
||||||
|
"""Test that areset calls storage.areset."""
|
||||||
|
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||||
|
mock_storage.areset = AsyncMock()
|
||||||
|
|
||||||
|
knowledge = Knowledge(
|
||||||
|
collection_name="test",
|
||||||
|
sources=[],
|
||||||
|
storage=mock_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
await knowledge.areset()
|
||||||
|
|
||||||
|
mock_storage.areset.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_areset_raises_when_storage_not_initialized(self):
|
||||||
|
"""Test that areset raises ValueError when storage is None."""
|
||||||
|
knowledge = Knowledge(
|
||||||
|
collection_name="test",
|
||||||
|
sources=[],
|
||||||
|
storage=MagicMock(spec=KnowledgeStorage),
|
||||||
|
)
|
||||||
|
knowledge.storage = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Storage is not initialized"):
|
||||||
|
await knowledge.areset()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncStringKnowledgeSource:
|
||||||
|
"""Tests for async StringKnowledgeSource operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aadd_saves_documents_asynchronously(self):
|
||||||
|
"""Test that aadd chunks and saves documents asynchronously."""
|
||||||
|
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||||
|
mock_storage.asave = AsyncMock()
|
||||||
|
|
||||||
|
source = StringKnowledgeSource(content="Test content for async processing")
|
||||||
|
source.storage = mock_storage
|
||||||
|
|
||||||
|
await source.aadd()
|
||||||
|
|
||||||
|
mock_storage.asave.assert_called_once()
|
||||||
|
assert len(source.chunks) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aadd_raises_without_storage(self):
|
||||||
|
"""Test that aadd raises ValueError when storage is not set."""
|
||||||
|
source = StringKnowledgeSource(content="Test content")
|
||||||
|
source.storage = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No storage found"):
|
||||||
|
await source.aadd()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncBaseKnowledgeSource:
|
||||||
|
"""Tests for async _asave_documents method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_documents_calls_storage_asave(self):
|
||||||
|
"""Test that _asave_documents calls storage.asave."""
|
||||||
|
mock_storage = MagicMock(spec=KnowledgeStorage)
|
||||||
|
mock_storage.asave = AsyncMock()
|
||||||
|
|
||||||
|
source = StringKnowledgeSource(content="Test")
|
||||||
|
source.storage = mock_storage
|
||||||
|
source.chunks = ["chunk1", "chunk2"]
|
||||||
|
|
||||||
|
await source._asave_documents()
|
||||||
|
|
||||||
|
mock_storage.asave.assert_called_once_with(["chunk1", "chunk2"])
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_documents_raises_without_storage(self):
|
||||||
|
"""Test that _asave_documents raises ValueError when storage is None."""
|
||||||
|
source = StringKnowledgeSource(content="Test")
|
||||||
|
source.storage = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No storage found"):
|
||||||
|
await source._asave_documents()
|
||||||
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
"""Tests for async memory operations."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
|
from crewai.memory.external.external_memory import ExternalMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Fixture to create a mock agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Search relevant data and provide results",
|
||||||
|
backstory="You are a researcher at a leading tech think tank.",
|
||||||
|
tools=[],
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_task(mock_agent):
|
||||||
|
"""Fixture to create a mock task."""
|
||||||
|
return Task(
|
||||||
|
description="Perform a search on specific topics.",
|
||||||
|
expected_output="A list of relevant URLs based on the search query.",
|
||||||
|
agent=mock_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def short_term_memory(mock_agent, mock_task):
|
||||||
|
"""Fixture to create a ShortTermMemory instance."""
|
||||||
|
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def long_term_memory(tmp_path):
|
||||||
|
"""Fixture to create a LongTermMemory instance."""
|
||||||
|
db_path = str(tmp_path / "test_ltm.db")
|
||||||
|
return LongTermMemory(path=db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||||
|
"""Fixture to create an EntityMemory instance."""
|
||||||
|
return EntityMemory(
|
||||||
|
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||||
|
path=str(tmp_path / "test_entities"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncShortTermMemory:
|
||||||
|
"""Tests for async ShortTermMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_emits_events(self, short_term_memory):
|
||||||
|
"""Test that asave emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
await short_term_memory.asave(
|
||||||
|
value="async test value",
|
||||||
|
metadata={"task": "async_test_task"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, short_term_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await short_term_memory.asearch(
|
||||||
|
query="async test query",
|
||||||
|
limit=3,
|
||||||
|
score_threshold=0.35,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncLongTermMemory:
|
||||||
|
"""Tests for async LongTermMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_emits_events(self, long_term_memory):
|
||||||
|
"""Test that asave emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
item = LongTermMemoryItem(
|
||||||
|
task="async test task",
|
||||||
|
agent="test_agent",
|
||||||
|
expected_output="test output",
|
||||||
|
datetime="2024-01-01T00:00:00",
|
||||||
|
quality=0.9,
|
||||||
|
metadata={"task": "async test task", "quality": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
await long_term_memory.asave(item)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, long_term_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||||
|
"""Test that asave followed by asearch works correctly."""
|
||||||
|
item = LongTermMemoryItem(
|
||||||
|
task="integration test task",
|
||||||
|
agent="test_agent",
|
||||||
|
expected_output="test output",
|
||||||
|
datetime="2024-01-01T00:00:00",
|
||||||
|
quality=0.9,
|
||||||
|
metadata={"task": "integration test task", "quality": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
await long_term_memory.asave(item)
|
||||||
|
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||||
|
|
||||||
|
assert results is not None
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncEntityMemory:
|
||||||
|
"""Tests for async EntityMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||||
|
"""Test that asave with a single item emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
item = EntityMemoryItem(
|
||||||
|
name="TestEntity",
|
||||||
|
type="Person",
|
||||||
|
description="A test entity for async operations",
|
||||||
|
relationships="Related to other test entities",
|
||||||
|
)
|
||||||
|
|
||||||
|
await entity_memory.asave(item)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, entity_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncContextualMemory:
|
||||||
|
"""Tests for async ContextualMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
||||||
|
"""Test that abuild_context_for_task returns empty string for empty query."""
|
||||||
|
mock_task.description = ""
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
||||||
|
"""Test that abuild_context_for_task handles None memory sources."""
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
||||||
|
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(
|
||||||
|
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
||||||
|
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=mock_em,
|
||||||
|
exm=mock_exm,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
||||||
|
|
||||||
|
assert "Recent Insights:" in result
|
||||||
|
assert "STM insight" in result
|
||||||
|
assert "Historical Data:" in result
|
||||||
|
assert "LTM suggestion" in result
|
||||||
|
assert "Entities:" in result
|
||||||
|
assert "Entity info" in result
|
||||||
|
assert "External memories:" in result
|
||||||
|
assert "External memory" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_stm_context returns properly formatted results."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "First insight"},
|
||||||
|
{"content": "Second insight"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_stm_context("test query")
|
||||||
|
|
||||||
|
assert "Recent Insights:" in result
|
||||||
|
assert "- First insight" in result
|
||||||
|
assert "- Second insight" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_ltm_context returns properly formatted results."""
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_ltm_context("test task")
|
||||||
|
|
||||||
|
assert "Historical Data:" in result
|
||||||
|
assert "- Suggestion 1" in result
|
||||||
|
assert "- Suggestion 2" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_entity_context returns properly formatted results."""
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "Entity A details"},
|
||||||
|
{"content": "Entity B details"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=mock_em,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_entity_context("test query")
|
||||||
|
|
||||||
|
assert "Entities:" in result
|
||||||
|
assert "- Entity A details" in result
|
||||||
|
assert "- Entity B details" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_external_context_returns_formatted_results(self):
|
||||||
|
"""Test that _afetch_external_context returns properly formatted results."""
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "External data 1"},
|
||||||
|
{"content": "External data 2"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=mock_exm,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_external_context("test query")
|
||||||
|
|
||||||
|
assert "External memories:" in result
|
||||||
|
assert "- External data 1" in result
|
||||||
|
assert "- External data 2" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_methods_return_empty_for_empty_results(self):
|
||||||
|
"""Test that async fetch methods return empty string for no results."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=mock_em,
|
||||||
|
exm=mock_exm,
|
||||||
|
)
|
||||||
|
|
||||||
|
stm_result = await contextual_memory._afetch_stm_context("query")
|
||||||
|
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
||||||
|
em_result = await contextual_memory._afetch_entity_context("query")
|
||||||
|
exm_result = await contextual_memory._afetch_external_context("query")
|
||||||
|
|
||||||
|
assert stm_result == ""
|
||||||
|
assert ltm_result is None
|
||||||
|
assert em_result == ""
|
||||||
|
assert exm_result == ""
|
||||||
386
lib/crewai/tests/task/test_async_task.py
Normal file
386
lib/crewai/tests/task/test_async_task.py
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
"""Tests for async task execution."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
from crewai.tasks.output_format import OutputFormat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_agent() -> Agent:
|
||||||
|
"""Create a test agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskExecution:
|
||||||
|
"""Tests for async task execution methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_basic(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test basic async task execution."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, TaskOutput)
|
||||||
|
assert result.raw == "Async task result"
|
||||||
|
assert result.agent == "Test Agent"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_context(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async task execution with context."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = "Additional context for the task"
|
||||||
|
result = await task.aexecute_sync(context=context)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert task.prompt_context == context
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
call_kwargs = mock_execute.call_args[1]
|
||||||
|
assert call_kwargs["context"] == context
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_tools(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async task execution with custom tools."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
|
||||||
|
result = await task.aexecute_sync(tools=[mock_tool])
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
call_kwargs = mock_execute.call_args[1]
|
||||||
|
assert mock_tool in call_kwargs["tools"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_sets_start_and_end_time(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution sets start and end times."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.start_time is None
|
||||||
|
assert task.end_time is None
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert task.start_time is not None
|
||||||
|
assert task.end_time is not None
|
||||||
|
assert task.end_time >= task.start_time
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_stores_output(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution stores the output."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.output is None
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert task.output is not None
|
||||||
|
assert task.output.raw == "Async task result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_adds_agent_to_processed_by(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution adds agent to processed_by_agents."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(task.processed_by_agents) == 0
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "Test Agent" in task.processed_by_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_calls_callback(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution calls the callback."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
callback = MagicMock()
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
callback.assert_called_once()
|
||||||
|
assert isinstance(callback.call_args[0][0], TaskOutput)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aexecute_sync_without_agent_raises(self) -> None:
|
||||||
|
"""Test that async execution without agent raises exception."""
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "has no agent assigned" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_with_different_agent(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with a different agent than assigned."""
|
||||||
|
mock_execute.return_value = "Other agent result"
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
other_agent = Agent(
|
||||||
|
role="Other Agent",
|
||||||
|
goal="Other goal",
|
||||||
|
backstory="Other backstory",
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync(agent=other_agent)
|
||||||
|
|
||||||
|
assert result.raw == "Other agent result"
|
||||||
|
assert result.agent == "Other Agent"
|
||||||
|
mock_execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_handles_exception(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that async execution handles exceptions properly."""
|
||||||
|
mock_execute.side_effect = RuntimeError("Test error")
|
||||||
|
task = Task(
|
||||||
|
description="Test task description",
|
||||||
|
expected_output="Test expected output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "Test error" in str(exc_info.value)
|
||||||
|
assert task.end_time is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncGuardrails:
|
||||||
|
"""Tests for async guardrail invocation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_success(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail invocation with successful validation."""
|
||||||
|
mock_execute.return_value = "Async task result"
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.raw == "Async task result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_failure_then_success(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail that fails then succeeds on retry."""
|
||||||
|
mock_execute.side_effect = ["First result", "Second result"]
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return False, "First attempt failed"
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_guardrail_max_retries_exceeded(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async guardrail that exceeds max retries."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
|
||||||
|
def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
return False, "Always fails"
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
guardrail_max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert "validation after" in str(exc_info.value)
|
||||||
|
assert "2 retries" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_ainvoke_multiple_guardrails(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with multiple guardrails."""
|
||||||
|
mock_execute.return_value = "Async result"
|
||||||
|
guardrail1_called = False
|
||||||
|
guardrail2_called = False
|
||||||
|
|
||||||
|
def guardrail1(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal guardrail1_called
|
||||||
|
guardrail1_called = True
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
def guardrail2(output: TaskOutput) -> tuple[bool, str]:
|
||||||
|
nonlocal guardrail2_called
|
||||||
|
guardrail2_called = True
|
||||||
|
return True, output.raw
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
guardrails=[guardrail1, guardrail2],
|
||||||
|
)
|
||||||
|
|
||||||
|
await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert guardrail1_called
|
||||||
|
assert guardrail2_called
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskOutput:
|
||||||
|
"""Tests for async task output handling."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_output_format_raw(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test async execution with raw output format."""
|
||||||
|
mock_execute.return_value = '{"key": "value"}'
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=test_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result.output_format == OutputFormat.RAW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
|
||||||
|
async def test_aexecute_sync_task_output_attributes(
|
||||||
|
self, mock_execute: AsyncMock, test_agent: Agent
|
||||||
|
) -> None:
|
||||||
|
"""Test that task output has correct attributes."""
|
||||||
|
mock_execute.return_value = "Test result"
|
||||||
|
task = Task(
|
||||||
|
description="Test description",
|
||||||
|
expected_output="Test expected",
|
||||||
|
agent=test_agent,
|
||||||
|
name="Test Task Name",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await task.aexecute_sync()
|
||||||
|
|
||||||
|
assert result.name == "Test Task Name"
|
||||||
|
assert result.description == "Test description"
|
||||||
|
assert result.expected_output == "Test expected"
|
||||||
|
assert result.raw == "Test result"
|
||||||
|
assert result.agent == "Test Agent"
|
||||||
@@ -1492,3 +1492,144 @@ def test_flow_copy_state_with_dict_state():
|
|||||||
|
|
||||||
flow.state["test"] = "modified"
|
flow.state["test"] = "modified"
|
||||||
assert copied_state["test"] == "value"
|
assert copied_state["test"] == "value"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlowAkickoff:
|
||||||
|
"""Tests for the native async akickoff method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_basic(self):
|
||||||
|
"""Test basic akickoff execution."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class SimpleFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
return "step_1_result"
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self, result):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
return "final_result"
|
||||||
|
|
||||||
|
flow = SimpleFlow()
|
||||||
|
result = await flow.akickoff()
|
||||||
|
|
||||||
|
assert execution_order == ["step_1", "step_2"]
|
||||||
|
assert result == "final_result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_with_inputs(self):
|
||||||
|
"""Test akickoff with inputs."""
|
||||||
|
|
||||||
|
class InputFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def process_input(self):
|
||||||
|
return self.state.get("value", "default")
|
||||||
|
|
||||||
|
flow = InputFlow()
|
||||||
|
result = await flow.akickoff(inputs={"value": "custom_value"})
|
||||||
|
|
||||||
|
assert result == "custom_value"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_with_async_methods(self):
|
||||||
|
"""Test akickoff with async flow methods."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class AsyncMethodFlow(Flow):
|
||||||
|
@start()
|
||||||
|
async def async_step_1(self):
|
||||||
|
execution_order.append("async_step_1")
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return "async_result"
|
||||||
|
|
||||||
|
@listen(async_step_1)
|
||||||
|
async def async_step_2(self, result):
|
||||||
|
execution_order.append("async_step_2")
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return f"final_{result}"
|
||||||
|
|
||||||
|
flow = AsyncMethodFlow()
|
||||||
|
result = await flow.akickoff()
|
||||||
|
|
||||||
|
assert execution_order == ["async_step_1", "async_step_2"]
|
||||||
|
assert result == "final_async_result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_equivalent_to_kickoff_async(self):
|
||||||
|
"""Test that akickoff produces the same results as kickoff_async."""
|
||||||
|
execution_order_akickoff = []
|
||||||
|
execution_order_kickoff_async = []
|
||||||
|
|
||||||
|
class TestFlow(Flow):
|
||||||
|
def __init__(self, execution_list):
|
||||||
|
super().__init__()
|
||||||
|
self._execution_list = execution_list
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self._execution_list.append("step_1")
|
||||||
|
return "result_1"
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self, result):
|
||||||
|
self._execution_list.append("step_2")
|
||||||
|
return "result_2"
|
||||||
|
|
||||||
|
flow1 = TestFlow(execution_order_akickoff)
|
||||||
|
result1 = await flow1.akickoff()
|
||||||
|
|
||||||
|
flow2 = TestFlow(execution_order_kickoff_async)
|
||||||
|
result2 = await flow2.kickoff_async()
|
||||||
|
|
||||||
|
assert execution_order_akickoff == execution_order_kickoff_async
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_with_multiple_starts(self):
|
||||||
|
"""Test akickoff with multiple start methods."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class MultiStartFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start_a(self):
|
||||||
|
execution_order.append("start_a")
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def start_b(self):
|
||||||
|
execution_order.append("start_b")
|
||||||
|
|
||||||
|
flow = MultiStartFlow()
|
||||||
|
await flow.akickoff()
|
||||||
|
|
||||||
|
assert "start_a" in execution_order
|
||||||
|
assert "start_b" in execution_order
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_akickoff_with_router(self):
|
||||||
|
"""Test akickoff with router method."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class RouterFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def begin(self):
|
||||||
|
execution_order.append("begin")
|
||||||
|
return "data"
|
||||||
|
|
||||||
|
@router(begin)
|
||||||
|
def route(self, data):
|
||||||
|
execution_order.append("route")
|
||||||
|
return "PATH_A"
|
||||||
|
|
||||||
|
@listen("PATH_A")
|
||||||
|
def handle_path_a(self):
|
||||||
|
execution_order.append("path_a")
|
||||||
|
return "path_a_result"
|
||||||
|
|
||||||
|
flow = RouterFlow()
|
||||||
|
result = await flow.akickoff()
|
||||||
|
|
||||||
|
assert execution_order == ["begin", "route", "path_a"]
|
||||||
|
assert result == "path_a_result"
|
||||||
|
|||||||
52
uv.lock
generated
52
uv.lock
generated
@@ -247,6 +247,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
|
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aiosqlite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-doc"
|
name = "annotated-doc"
|
||||||
version = "0.0.4"
|
version = "0.0.4"
|
||||||
@@ -1100,6 +1112,7 @@ wheels = [
|
|||||||
name = "crewai"
|
name = "crewai"
|
||||||
source = { editable = "lib/crewai" }
|
source = { editable = "lib/crewai" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "aiosqlite" },
|
||||||
{ name = "appdirs" },
|
{ name = "appdirs" },
|
||||||
{ name = "chromadb" },
|
{ name = "chromadb" },
|
||||||
{ name = "click" },
|
{ name = "click" },
|
||||||
@@ -1183,6 +1196,7 @@ watson = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
|
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
|
||||||
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
||||||
|
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
||||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
|
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
|
||||||
{ name = "appdirs", specifier = "~=1.4.4" },
|
{ name = "appdirs", specifier = "~=1.4.4" },
|
||||||
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
|
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
|
||||||
@@ -1843,7 +1857,7 @@ name = "exceptiongroup"
|
|||||||
version = "1.3.1"
|
version = "1.3.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
@@ -4337,7 +4351,7 @@ name = "nvidia-cudnn-cu12"
|
|||||||
version = "9.10.2.21"
|
version = "9.10.2.21"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-cublas-cu12" },
|
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
||||||
@@ -4348,7 +4362,7 @@ name = "nvidia-cufft-cu12"
|
|||||||
version = "11.3.3.83"
|
version = "11.3.3.83"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-nvjitlink-cu12" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
||||||
@@ -4375,9 +4389,9 @@ name = "nvidia-cusolver-cu12"
|
|||||||
version = "11.7.3.90"
|
version = "11.7.3.90"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-cublas-cu12" },
|
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
{ name = "nvidia-cusparse-cu12" },
|
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
{ name = "nvidia-nvjitlink-cu12" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
||||||
@@ -4388,7 +4402,7 @@ name = "nvidia-cusparse-cu12"
|
|||||||
version = "12.5.8.93"
|
version = "12.5.8.93"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-nvjitlink-cu12" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
||||||
@@ -4448,9 +4462,9 @@ name = "ocrmac"
|
|||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "click" },
|
{ name = "click", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-vision" },
|
{ name = "pyobjc-framework-vision", marker = "sys_platform == 'darwin'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/dd/dc/de3e9635774b97d9766f6815bbb3f5ec9bce347115f10d9abbf2733a9316/ocrmac-1.0.0.tar.gz", hash = "sha256:5b299e9030c973d1f60f82db000d6c2e5ff271601878c7db0885e850597d1d2e", size = 1463997, upload-time = "2024-11-07T12:00:00.197Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/dd/dc/de3e9635774b97d9766f6815bbb3f5ec9bce347115f10d9abbf2733a9316/ocrmac-1.0.0.tar.gz", hash = "sha256:5b299e9030c973d1f60f82db000d6c2e5ff271601878c7db0885e850597d1d2e", size = 1463997, upload-time = "2024-11-07T12:00:00.197Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
@@ -6050,7 +6064,7 @@ name = "pyobjc-framework-cocoa"
|
|||||||
version = "12.1"
|
version = "12.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "pyobjc-core" },
|
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
@@ -6066,8 +6080,8 @@ name = "pyobjc-framework-coreml"
|
|||||||
version = "12.1"
|
version = "12.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "pyobjc-core" },
|
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-cocoa" },
|
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/30/2d/baa9ea02cbb1c200683cb7273b69b4bee5070e86f2060b77e6a27c2a9d7e/pyobjc_framework_coreml-12.1.tar.gz", hash = "sha256:0d1a4216891a18775c9e0170d908714c18e4f53f9dc79fb0f5263b2aa81609ba", size = 40465, upload-time = "2025-11-14T10:14:02.265Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/30/2d/baa9ea02cbb1c200683cb7273b69b4bee5070e86f2060b77e6a27c2a9d7e/pyobjc_framework_coreml-12.1.tar.gz", hash = "sha256:0d1a4216891a18775c9e0170d908714c18e4f53f9dc79fb0f5263b2aa81609ba", size = 40465, upload-time = "2025-11-14T10:14:02.265Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
@@ -6083,8 +6097,8 @@ name = "pyobjc-framework-quartz"
|
|||||||
version = "12.1"
|
version = "12.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "pyobjc-core" },
|
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-cocoa" },
|
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
@@ -6100,10 +6114,10 @@ name = "pyobjc-framework-vision"
|
|||||||
version = "12.1"
|
version = "12.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "pyobjc-core" },
|
{ name = "pyobjc-core", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-cocoa" },
|
{ name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-coreml" },
|
{ name = "pyobjc-framework-coreml", marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "pyobjc-framework-quartz" },
|
{ name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/c2/5a/08bb3e278f870443d226c141af14205ff41c0274da1e053b72b11dfc9fb2/pyobjc_framework_vision-12.1.tar.gz", hash = "sha256:a30959100e85dcede3a786c544e621ad6eb65ff6abf85721f805822b8c5fe9b0", size = 59538, upload-time = "2025-11-14T10:23:21.979Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/c2/5a/08bb3e278f870443d226c141af14205ff41c0274da1e053b72b11dfc9fb2/pyobjc_framework_vision-12.1.tar.gz", hash = "sha256:a30959100e85dcede3a786c544e621ad6eb65ff6abf85721f805822b8c5fe9b0", size = 59538, upload-time = "2025-11-14T10:23:21.979Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
|
|||||||
Reference in New Issue
Block a user