Compare commits

..

2 Commits

25 changed files with 357 additions and 405 deletions

View File

@@ -9,6 +9,7 @@ from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
from crewai.utilities import Logger
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -62,8 +63,12 @@ class Agent(BaseAgent):
tools: Tools at agents disposal
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
allow_feedback: Whether the agent can receive and process feedback during execution.
allow_conflict: Whether the agent can handle conflicts with other agents during execution.
allow_iteration: Whether the agent can iterate on its solutions based on feedback and validation.
"""
_logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
default=None,
@@ -123,6 +128,18 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
allow_feedback: bool = Field(
default=False,
description="Enable agent to receive and process feedback during execution.",
)
allow_conflict: bool = Field(
default=False,
description="Enable agent to handle conflicts with other agents during execution.",
)
allow_iteration: bool = Field(
default=False,
description="Enable agent to iterate on its solutions based on feedback and validation.",
)
embedder_config: Optional[Dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
@@ -139,6 +156,19 @@ class Agent(BaseAgent):
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
if self.allow_feedback:
self._logger.log("info", "Feedback mode enabled for agent.", color="bold_green")
if self.allow_conflict:
self._logger.log("info", "Conflict handling enabled for agent.", color="bold_green")
if self.allow_iteration:
self._logger.log("info", "Iteration mode enabled for agent.", color="bold_green")
# Validate boolean parameters
for param in ['allow_feedback', 'allow_conflict', 'allow_iteration']:
if not isinstance(getattr(self, param), bool):
raise ValueError(f"Parameter '{param}' must be a boolean value.")
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
@@ -294,7 +324,14 @@ class Agent(BaseAgent):
)
if self.crew and self.crew.memory:
memory = self.crew.contextual_memory.build_context_for_task(task, context)
contextual_memory = ContextualMemory(
self.crew.memory_config,
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._user_memory,
)
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
@@ -393,6 +430,9 @@ class Agent(BaseAgent):
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
allow_feedback=self.allow_feedback,
allow_conflict=self.allow_conflict,
allow_iteration=self.allow_iteration,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
),

View File

@@ -31,6 +31,34 @@ class ToolResult:
class CrewAgentExecutor(CrewAgentExecutorMixin):
"""CrewAgentExecutor class for managing agent execution.
This class is responsible for executing agent tasks, handling tools,
managing agent interactions, and processing the results.
Parameters:
llm: The language model to use for generating responses.
task: The task to be executed.
crew: The crew that the agent belongs to.
agent: The agent to execute the task.
prompt: The prompt to use for generating responses.
max_iter: Maximum number of iterations for the agent execution.
tools: The tools available to the agent.
tools_names: The names of the tools available to the agent.
stop_words: Words that signal the end of agent execution.
tools_description: Description of the tools available to the agent.
tools_handler: Handler for tool operations.
step_callback: Callback function for each step of execution.
original_tools: Original list of tools before processing.
function_calling_llm: LLM specifically for function calling.
respect_context_window: Whether to respect the context window size.
request_within_rpm_limit: Function to check if request is within RPM limit.
callbacks: List of callback functions.
allow_feedback: Controls feedback processing during execution.
allow_conflict: Enables conflict handling between agents.
allow_iteration: Allows solution iteration based on feedback.
"""
_logger: Logger = Logger()
def __init__(
@@ -52,6 +80,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window: bool = False,
request_within_rpm_limit: Any = None,
callbacks: List[Any] = [],
allow_feedback: bool = False,
allow_conflict: bool = False,
allow_iteration: bool = False,
):
self._i18n: I18N = I18N()
self.llm = llm
@@ -73,6 +104,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.allow_feedback = allow_feedback
self.allow_conflict = allow_conflict
self.allow_iteration = allow_iteration
self.ask_for_human_input = False
self.messages: List[Dict[str, str]] = []
self.iterations = 0
@@ -358,9 +392,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration
if agent_id in training_data and isinstance(train_iteration, int):
training_data[agent_id][train_iteration]["improved_output"] = (
result.output
)
training_data[agent_id][train_iteration][
"improved_output"
] = result.output
training_handler.save(training_data)
else:
self._printer.print(
@@ -487,3 +521,56 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.ask_for_human_input = False
return formatted_answer
def process_feedback(self, feedback: str) -> bool:
"""
Process feedback for the agent if feedback mode is enabled.
Parameters:
feedback (str): The feedback to process.
Returns:
bool: True if the feedback was processed successfully, False otherwise.
"""
if not self.allow_feedback:
self._logger.log("warning", "Feedback processing skipped (allow_feedback=False).", color="yellow")
return False
self._logger.log("info", f"Processing feedback: {feedback}", color="green")
# Add feedback to messages
self.messages.append(self._format_msg(f"Feedback: {feedback}"))
return True
def handle_conflict(self, other_agent: 'CrewAgentExecutor') -> bool:
"""
Handle conflict with another agent if conflict handling is enabled.
Parameters:
other_agent (CrewAgentExecutor): The other agent involved in the conflict.
Returns:
bool: True if the conflict was handled successfully, False otherwise.
"""
if not self.allow_conflict:
self._logger.log("warning", "Conflict handling skipped (allow_conflict=False).", color="yellow")
return False
self._logger.log("info", f"Handling conflict with agent: {other_agent.agent.role}", color="green")
return True
def process_iteration(self, result: Any) -> bool:
"""
Process iteration based on result if iteration mode is enabled.
Parameters:
result (Any): The result to iterate on.
Returns:
bool: True if the iteration was processed successfully, False otherwise.
"""
if not self.allow_iteration:
self._logger.log("warning", "Iteration processing skipped (allow_iteration=False).", color="yellow")
return False
self._logger.log("info", "Processing iteration on result.", color="green")
return True

View File

@@ -153,12 +153,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
login_response_json = login_response.json()
settings = Settings()
settings.tool_repository_username = login_response_json["credential"][
"username"
]
settings.tool_repository_password = login_response_json["credential"][
"password"
]
settings.tool_repository_username = login_response_json["credential"]["username"]
settings.tool_repository_password = login_response_json["credential"]["password"]
settings.dump()
console.print(
@@ -183,7 +179,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
capture_output=False,
env=self._build_env_with_credentials(repository_handle),
text=True,
check=True,
check=True
)
if add_package_result.stderr:
@@ -208,11 +204,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(settings.tool_repository_username or "")
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(settings.tool_repository_password or "")
return env

View File

@@ -25,7 +25,6 @@ from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -279,13 +278,6 @@ class Crew(BaseModel):
)
else:
self._user_memory = None
self.contextual_memory = ContextualMemory(
memory_config=self.memory_config,
stm=self._short_term_memory,
ltm=self._long_term_memory,
em=self._entity_memory,
um=self._user_memory,
)
return self
@model_validator(mode="after")

View File

@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
embedder_config: Optional[Dict[str, Any]] = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
embedder_config: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
@@ -49,13 +49,8 @@ class Knowledge(BaseModel):
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
results = self.storage.search(
query,
limit,

View File

@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
@@ -62,10 +62,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _save_documents(self):
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
self.storage.save(self.chunks)
def convert_to_path(self, path: Union[Path, str]) -> Path:
"""Convert a path to a Path object."""

View File

@@ -16,7 +16,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
@@ -46,7 +46,4 @@ class BaseKnowledgeSource(BaseModel, ABC):
Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
self.storage.save(self.chunks)

View File

@@ -1,5 +1,4 @@
from typing import Any, Dict, Optional
from crewai.task import Task
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
@@ -11,7 +10,7 @@ class ContextualMemory:
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: Optional[UserMemory],
um: UserMemory,
):
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
@@ -22,7 +21,7 @@ class ContextualMemory:
self.em = em
self.um = um
def build_context_for_task(self, task: Task, context: str) -> str:
def build_context_for_task(self, task, context) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
@@ -40,7 +39,7 @@ class ContextualMemory:
context.append(self._fetch_user_context(query))
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query: str) -> str:
def _fetch_stm_context(self, query) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
@@ -54,7 +53,7 @@ class ContextualMemory:
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task: str) -> Optional[str]:
def _fetch_ltm_context(self, task) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -73,7 +72,7 @@ class ContextualMemory:
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query: str) -> str:
def _fetch_entity_context(self, query) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
@@ -95,8 +94,6 @@ class ContextualMemory:
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
if not self.um:
return ""
user_memories = self.um.search(query)
if not user_memories:
return ""

View File

@@ -11,7 +11,7 @@ class EntityMemory(Memory):
"""
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
if hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None

View File

@@ -15,17 +15,8 @@ class LongTermMemory(Memory):
"""
def __init__(self, storage=None, path=None):
"""Initialize long term memory.
Args:
storage: Optional custom storage instance
path: Optional custom path for storage location
Note:
If both storage and path are provided, storage takes precedence
"""
if not storage:
storage = LTMSQLiteStorage(storage_path=path) if path else LTMSQLiteStorage()
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"

View File

@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
"""
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
if hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None

View File

@@ -1,11 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
import os
from typing import Any, Dict, List, Optional, TypeVar
from abc import ABC, abstractmethod
from pathlib import Path
from crewai.utilities.paths import get_default_storage_path
from typing import Any, Dict, List, Optional
class BaseRAGStorage(ABC):
@@ -18,46 +12,17 @@ class BaseRAGStorage(ABC):
def __init__(
self,
type: str,
storage_path: Optional[Path] = None,
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
crew: Any = None,
) -> None:
"""Initialize the BaseRAGStorage.
Args:
type: Type of storage being used
storage_path: Optional custom path for storage location
allow_reset: Whether storage can be reset
embedder_config: Optional configuration for the embedder
crew: Optional crew instance this storage belongs to
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
):
self.type = type
self.storage_path = storage_path if storage_path else get_default_storage_path('rag')
# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")
self.allow_reset = allow_reset
self.embedder_config = embedder_config
self.crew = crew
self.agents = self._initialize_agents()
def _initialize_agents(self) -> str:
"""Initialize agent identifiers for storage.
Returns:
str: Underscore-joined string of sanitized agent role names
"""
if self.crew:
return "_".join(
[self._sanitize_role(agent.role) for agent in self.crew.agents]
@@ -66,27 +31,12 @@ class BaseRAGStorage(ABC):
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names.
Args:
role: The agent role name to sanitize
Returns:
str: Sanitized role name safe for use in paths
"""
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
"""Save a value with metadata to the storage.
Args:
value: The value to store
metadata: Additional metadata to store with the value
Raises:
OSError: If there is an error writing to storage
"""
"""Save a value with metadata to the storage."""
pass
@abstractmethod
@@ -96,55 +46,25 @@ class BaseRAGStorage(ABC):
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
"""Search for entries in the storage.
Args:
query: The search query string
limit: Maximum number of results to return
filter: Optional filter criteria
score_threshold: Minimum similarity score threshold
Returns:
List[Dict[str, Any]]: List of matching entries with their metadata
"""
) -> List[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage.
Raises:
OSError: If there is an error clearing storage
PermissionError: If reset is not allowed
"""
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> List[float]:
"""Generate an embedding for the given text and metadata.
Args:
text: Text to generate embedding for
metadata: Optional metadata to include in embedding
Returns:
List[float]: Vector embedding of the text
Raises:
ValueError: If text is empty or invalid
"""
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self) -> None:
"""Initialize the vector db.
Raises:
OSError: If vector db initialization fails
"""
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):

View File

@@ -1,13 +1,11 @@
import json
import os
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from crewai.task import Task
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import get_default_storage_path
from crewai.utilities.paths import db_storage_path
class KickoffTaskOutputsSQLiteStorage:
@@ -15,26 +13,10 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""
def __init__(self, storage_path: Optional[Path] = None) -> None:
"""Initialize kickoff task outputs storage.
Args:
storage_path: Optional custom path for storage location
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
self.storage_path = storage_path if storage_path else get_default_storage_path('kickoff')
# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")
def __init__(
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
@@ -43,7 +25,7 @@ class KickoffTaskOutputsSQLiteStorage:
Initializes the SQLite database and creates LTM table
"""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -73,21 +55,9 @@ class KickoffTaskOutputsSQLiteStorage:
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] = {},
) -> None:
"""Add a task output to storage.
Args:
task: The task whose output is being stored
output: The output data from the task
task_index: Index of this task in the sequence
was_replayed: Whether this was from a replay
inputs: Optional input data that led to this output
Raises:
sqlite3.Error: If there is an error saving to database
"""
):
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -120,7 +90,7 @@ class KickoffTaskOutputsSQLiteStorage:
Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
"""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
fields = []
@@ -149,7 +119,7 @@ class KickoffTaskOutputsSQLiteStorage:
def load(self) -> Optional[List[Dict[str, Any]]]:
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
@@ -185,7 +155,7 @@ class KickoffTaskOutputsSQLiteStorage:
Deletes all rows from the latest_kickoff_task_outputs table.
"""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()

View File

@@ -1,11 +1,9 @@
import json
import os
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from crewai.utilities import Printer
from crewai.utilities.paths import get_default_storage_path
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
@@ -13,26 +11,10 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""
def __init__(self, storage_path: Optional[Path] = None) -> None:
"""Initialize LTM SQLite storage.
Args:
storage_path: Optional custom path for storage location
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
self.storage_path = storage_path if storage_path else get_default_storage_path('ltm')
# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")
def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
@@ -41,7 +23,7 @@ class LTMSQLiteStorage:
Initializes the SQLite database and creates LTM table
"""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -69,20 +51,9 @@ class LTMSQLiteStorage:
datetime: str,
score: Union[int, float],
) -> None:
"""Save a memory entry to long-term memory.
Args:
task_description: Description of the task this memory relates to
metadata: Additional data to store with the memory
datetime: Timestamp for when this memory was created
score: Relevance score for this memory (higher is more relevant)
Raises:
sqlite3.Error: If there is an error saving to the database
"""
"""Saves data to the LTM table with error handling."""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
@@ -103,7 +74,7 @@ class LTMSQLiteStorage:
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
f"""
@@ -138,7 +109,7 @@ class LTMSQLiteStorage:
) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(str(self.storage_path)) as conn:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM long_term_memories")
conn.commit()

View File

@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
self.memory_type = type
self.crew = crew
self.memory_config = crew.memory_config if crew else None
self.memory_config = crew.memory_config
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
@@ -27,10 +27,9 @@ class Mem0Storage(Storage):
raise ValueError("User ID is required for user memory type")
# API key in memory config overrides the environment variable
if self.memory_config and self.memory_config.get("config"):
mem0_api_key = self.memory_config.get("config").get("api_key")
else:
mem0_api_key = os.getenv("MEM0_API_KEY")
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
"MEM0_API_KEY"
)
self.memory = MemoryClient(api_key=mem0_api_key)
def _sanitize_role(self, role: str) -> str:

View File

@@ -11,6 +11,7 @@ from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager
@@ -39,15 +40,9 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None
def __init__(
self,
type,
storage_path=None,
allow_reset=True,
embedder_config=None,
crew=None,
path=None,
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
super().__init__(type, storage_path, allow_reset, embedder_config, crew)
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
@@ -95,7 +90,7 @@ class RAGStorage(BaseRAGStorage):
"""
Ensures file name does not exceed max allowed by OS
"""
base_path = f"{self.storage_path}/{type}"
base_path = f"{db_storage_path()}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
@@ -157,7 +152,7 @@ class RAGStorage(BaseRAGStorage):
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{self.storage_path}/{self.type}")
shutil.rmtree(f"{db_storage_path()}/{self.type}")
self.app = None
self.collection = None
except Exception as e:

View File

@@ -66,6 +66,7 @@ def cache_handler(func):
def crew(func) -> Callable[..., Crew]:
@wraps(func)
def wrapper(self, *args, **kwargs) -> Crew:
instantiated_tasks = []

View File

@@ -216,5 +216,5 @@ def CrewBase(cls: T) -> T:
# Include base class (qual)name in the wrapper class (qual)name.
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
return cast(T, WrappedClass)

View File

@@ -373,9 +373,7 @@ class Task(BaseModel):
content = (
json_output
if json_output
else pydantic_output.model_dump_json()
if pydantic_output
else result
else pydantic_output.model_dump_json() if pydantic_output else result
)
self._save_file(content)

View File

@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
if embedder_config is None:
return self._create_default_embedding_function()
provider = embedder_config.get("provider", "")
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model")
@@ -38,13 +38,12 @@ class EmbeddingConfigurator:
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
embedding_function = self.embedding_functions.get(provider, None)
if not embedding_function:
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
return embedding_function(config, model_name)
return self.embedding_functions[provider](config, model_name)
@staticmethod
def _create_default_embedding_function():

View File

@@ -22,26 +22,3 @@ def get_project_directory_name():
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name
def get_default_storage_path(storage_type: str) -> Path:
"""Returns the default storage path for a given storage type.
Args:
storage_type: Type of storage ('ltm', 'kickoff', 'rag')
Returns:
Path: Default storage path for the specified type
Raises:
ValueError: If storage_type is not recognized
"""
base_path = db_storage_path()
if storage_type == 'ltm':
return base_path / 'latest_long_term_memories.db'
elif storage_type == 'kickoff':
return base_path / 'latest_kickoff_task_outputs.db'
elif storage_type == 'rag':
return base_path
else:
raise ValueError(f"Unknown storage type: {storage_type}")

View File

@@ -1625,3 +1625,127 @@ def test_agent_with_knowledge_sources():
# Assert that the agent provides the correct information
assert "red" in result.raw.lower()
def test_agent_with_feedback_conflict_iteration_params():
"""Test that the agent correctly handles the allow_feedback, allow_conflict, and allow_iteration parameters."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
allow_feedback=True,
allow_conflict=True,
allow_iteration=True,
)
assert agent.allow_feedback is True
assert agent.allow_conflict is True
assert agent.allow_iteration is True
# Create another agent with default values
default_agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
)
assert default_agent.allow_feedback is False
assert default_agent.allow_conflict is False
assert default_agent.allow_iteration is False
def test_agent_feedback_processing():
"""Test that the agent correctly processes feedback when allow_feedback is enabled."""
from unittest.mock import patch, MagicMock
# Create a mock CrewAgentExecutor
mock_executor = MagicMock()
mock_executor.allow_feedback = True
mock_executor.process_feedback.return_value = True
# Mock the create_agent_executor method at the module level
with patch('crewai.agent.Agent.create_agent_executor', return_value=mock_executor):
# Create an agent with allow_feedback=True
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
allow_feedback=True,
llm=MagicMock() # Mock LLM to avoid API calls
)
executor = agent.create_agent_executor()
assert executor.allow_feedback is True
result = executor.process_feedback("Test feedback")
assert result is True
executor.process_feedback.assert_called_once_with("Test feedback")
def test_agent_conflict_handling():
"""Test that the agent correctly handles conflicts when allow_conflict is enabled."""
from unittest.mock import patch, MagicMock
mock_executor1 = MagicMock()
mock_executor1.allow_conflict = True
mock_executor1.handle_conflict.return_value = True
mock_executor2 = MagicMock()
mock_executor2.allow_conflict = True
with patch('crewai.agent.Agent.create_agent_executor', return_value=mock_executor1):
# Create agents with allow_conflict=True
agent1 = Agent(
role="role1",
goal="goal1",
backstory="backstory1",
allow_conflict=True,
llm=MagicMock() # Mock LLM to avoid API calls
)
agent2 = Agent(
role="role2",
goal="goal2",
backstory="backstory2",
allow_conflict=True,
llm=MagicMock() # Mock LLM to avoid API calls
)
# Get the executors
executor1 = agent1.create_agent_executor()
executor2 = agent2.create_agent_executor()
assert executor1.allow_conflict is True
assert executor2.allow_conflict is True
result = executor1.handle_conflict(executor2)
assert result is True
executor1.handle_conflict.assert_called_once_with(executor2)
def test_agent_iteration_processing():
"""Test that the agent correctly processes iterations when allow_iteration is enabled."""
from unittest.mock import patch, MagicMock
# Create a mock CrewAgentExecutor
mock_executor = MagicMock()
mock_executor.allow_iteration = True
mock_executor.process_iteration.return_value = True
# Mock the create_agent_executor method at the module level
with patch('crewai.agent.Agent.create_agent_executor', return_value=mock_executor):
# Create an agent with allow_iteration=True
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
allow_iteration=True,
llm=MagicMock() # Mock LLM to avoid API calls
)
executor = agent.create_agent_executor()
assert executor.allow_iteration is True
result = executor.process_iteration("Test result")
assert result is True
executor.process_iteration.assert_called_once_with("Test result")

View File

@@ -28,10 +28,9 @@ def test_create_success(mock_subprocess):
with in_temp_dir():
tool_command = ToolCommand()
with (
patch.object(tool_command, "login") as mock_login,
patch("sys.stdout", new=StringIO()) as fake_out,
):
with patch.object(tool_command, "login") as mock_login, patch(
"sys.stdout", new=StringIO()
) as fake_out:
tool_command.create("test-tool")
output = fake_out.getvalue()
@@ -83,7 +82,7 @@ def test_install_success(mock_get, mock_subprocess_run):
capture_output=False,
text=True,
check=True,
env=unittest.mock.ANY,
env=unittest.mock.ANY
)
assert "Successfully installed sample-tool" in output

View File

@@ -1,83 +0,0 @@
import os
import tempfile
from pathlib import Path
import pytest
from unittest.mock import patch
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
from crewai.memory.storage.kickoff_task_outputs_storage import KickoffTaskOutputsSQLiteStorage
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import get_default_storage_path
class MockRAGStorage(BaseRAGStorage):
"""Mock implementation of BaseRAGStorage for testing."""
def _sanitize_role(self, role: str) -> str:
return role.lower()
def save(self, value, metadata):
pass
def search(self, query, limit=3, filter=None, score_threshold=0.35):
return []
def reset(self):
pass
def _generate_embedding(self, text, metadata=None):
return []
def _initialize_app(self):
pass
def test_default_storage_paths():
"""Test that default storage paths are created correctly."""
ltm_path = get_default_storage_path('ltm')
kickoff_path = get_default_storage_path('kickoff')
rag_path = get_default_storage_path('rag')
assert str(ltm_path).endswith('latest_long_term_memories.db')
assert str(kickoff_path).endswith('latest_kickoff_task_outputs.db')
assert isinstance(rag_path, Path)
def test_custom_storage_paths():
"""Test that custom storage paths are respected."""
with tempfile.TemporaryDirectory() as temp_dir:
custom_path = Path(temp_dir) / 'custom.db'
ltm = LTMSQLiteStorage(storage_path=custom_path)
assert ltm.storage_path == custom_path
kickoff = KickoffTaskOutputsSQLiteStorage(storage_path=custom_path)
assert kickoff.storage_path == custom_path
rag = MockRAGStorage('test', storage_path=custom_path)
assert rag.storage_path == custom_path
def test_directory_creation():
"""Test that storage directories are created automatically."""
with tempfile.TemporaryDirectory() as temp_dir:
test_dir = Path(temp_dir) / 'test_storage'
storage_path = test_dir / 'test.db'
assert not test_dir.exists()
LTMSQLiteStorage(storage_path=storage_path)
assert test_dir.exists()
def test_permission_error():
"""Test that permission errors are handled correctly."""
with tempfile.TemporaryDirectory() as temp_dir:
test_dir = Path(temp_dir) / 'readonly'
test_dir.mkdir()
os.chmod(test_dir, 0o444) # Read-only
storage_path = test_dir / 'test.db'
with pytest.raises((PermissionError, OSError)) as exc_info:
LTMSQLiteStorage(storage_path=storage_path)
# Verify that the error message mentions permission
assert "permission" in str(exc_info.value).lower()
def test_invalid_path():
"""Test that invalid paths raise appropriate errors."""
with pytest.raises(OSError):
# Try to create storage in a non-existent root directory
LTMSQLiteStorage(storage_path=Path('/nonexistent/dir/test.db'))

68
uv.lock generated
View File

@@ -1,18 +1,10 @@
version = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version < '3.11' and sys_platform == 'darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.12.4'",
]
[[package]]
@@ -308,7 +300,7 @@ name = "build"
version = "1.2.2.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "colorama", marker = "os_name == 'nt'" },
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
{ name = "packaging" },
{ name = "pyproject-hooks" },
@@ -543,7 +535,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@@ -650,6 +642,7 @@ tools = [
[package.dev-dependencies]
dev = [
{ name = "cairosvg" },
{ name = "crewai-tools" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
{ name = "mkdocs-material-extensions" },
@@ -703,6 +696,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "cairosvg", specifier = ">=2.7.1" },
{ name = "crewai-tools", specifier = ">=0.17.0" },
{ name = "mkdocs", specifier = ">=1.4.3" },
{ name = "mkdocs-material", specifier = ">=9.5.7" },
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
@@ -2468,7 +2462,7 @@ version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "ghp-import" },
{ name = "jinja2" },
{ name = "markdown" },
@@ -2649,7 +2643,7 @@ version = "2.10.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pygments" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
@@ -2896,7 +2890,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2923,9 +2917,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2936,7 +2930,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -3486,7 +3480,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@@ -5028,19 +5022,19 @@ dependencies = [
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -5087,7 +5081,7 @@ name = "tqdm"
version = "4.66.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
wheels = [
@@ -5130,7 +5124,7 @@ version = "0.27.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "outcome" },
@@ -5161,7 +5155,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },