style: apply ruff formatting and import sorting

This commit is contained in:
Greyson Lalonde
2026-03-12 22:06:56 -04:00
parent d2a156f244
commit 0228445080
20 changed files with 197 additions and 119 deletions

View File

@@ -1,4 +1,3 @@
from datetime import datetime
import json
import os
import time
@@ -10,8 +9,8 @@ from pydantic import BaseModel, Field
from pydantic.types import StringConstraints
import requests
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
from crewai_tools.tools.brave_search_tool.base import _save_results_to_file
from crewai_tools.tools.brave_search_tool.schemas import WebSearchParams
load_dotenv()

View File

@@ -18,7 +18,6 @@ class MergeAgentHandlerToolError(Exception):
"""Base exception for Merge Agent Handler tool errors."""
class MergeAgentHandlerTool(BaseTool):
"""
Wrapper for Merge Agent Handler tools.
@@ -174,7 +173,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tool = MergeAgentHandlerTool.from_tool_name(
... tool_name="linear__create_issue",
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa"
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... )
"""
# Create an empty args schema model (proper BaseModel subclass)
@@ -210,7 +209,10 @@ class MergeAgentHandlerTool(BaseTool):
if "parameters" in tool_schema:
try:
params = tool_schema["parameters"]
if params.get("type") == "object" and "properties" in params:
if (
params.get("type") == "object"
and "properties" in params
):
# Build field definitions for Pydantic
fields = {}
properties = params["properties"]
@@ -298,7 +300,7 @@ class MergeAgentHandlerTool(BaseTool):
>>> tools = MergeAgentHandlerTool.from_tool_pack(
... tool_pack_id="134e0111-0f67-44f6-98f0-597000290bb3",
... registered_user_id="91b2b905-e866-40c8-8be2-efe53827a0aa",
... tool_names=["linear__create_issue", "linear__get_issues"]
... tool_names=["linear__create_issue", "linear__get_issues"],
... )
"""
# Create a temporary instance to fetch the tool list

View File

@@ -110,11 +110,13 @@ class QdrantVectorSearchTool(BaseTool):
self.custom_embedding_fn(query)
if self.custom_embedding_fn
else (
lambda: __import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
lambda: (
__import__("openai")
.Client(api_key=os.getenv("OPENAI_API_KEY"))
.embeddings.create(input=[query], model="text-embedding-3-large")
.data[0]
.embedding
)
)()
)
results = self.client.query_points(

View File

@@ -137,7 +137,9 @@ class StagehandTool(BaseTool):
- 'observe': For finding elements in a specific area
"""
args_schema: type[BaseModel] = StagehandToolSchema
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
package_dependencies: list[str] = Field(
default_factory=lambda: ["stagehand<=0.5.9"]
)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(

View File

@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
ToolUsageStartedEvent,
)
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
args_dict, parse_error = parse_tool_call_args(
func_args, func_name, call_id, original_tool
)
if parse_error is not None:
return parse_error

View File

@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.unified_memory import Memory
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
storage = (
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
)
embedder = None
if embedder_config is not None:
from crewai.rag.embeddings.factory import build_embedder
embedder = build_embedder(embedder_config)
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
self._memory = (
Memory(storage=storage, embedder=embedder)
if embedder
else Memory(storage=storage)
)
except Exception as e:
self._init_error = str(e)
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
if len(record.content) > 80
else record.content
)
label = (
f"{date_str} "
f"[bold]{record.importance:.1f}[/] "
f"{preview}"
)
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
option_list.add_option(label)
def _populate_recall_list(self) -> None:
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
else m.record.content
)
label = (
f"[bold]\\[{m.score:.2f}][/] "
f"{preview} "
f"[dim]scope={m.record.scope}[/]"
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
)
option_list.add_option(label)
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
lines.append(
f"[dim]Created:[/] "
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
)
lines.append(
f"[dim]Last accessed:[/] "
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
panel = self.query_one("#info-panel", Static)
panel.loading = True
try:
scope = (
self._selected_scope
if self._selected_scope != "/"
else None
)
scope = self._selected_scope if self._selected_scope != "/" else None
loop = asyncio.get_event_loop()
matches = await loop.run_in_executor(
None,
lambda: self._memory.recall(
query, scope=scope, limit=10, depth="deep"
),
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
)
self._recall_matches = matches or []
self._view_mode = "recall"

View File

@@ -95,9 +95,7 @@ def reset_memories_command(
continue
if memory:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Memory has been reset."
)
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for search_path in search_paths:
for root, dirs, files in os.walk(search_path):
dirs[:] = [
d
for d in dirs
if d not in _SKIP_DIRS and not d.startswith(".")
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
]
if flow_path in files and "cli/templates" not in root:
file_os_path = os.path.join(root, flow_path)
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
try:
if flow_instance := get_flow_instance(
module_attr
):
if flow_instance := get_flow_instance(module_attr):
flow_instances.append(flow_instance)
except Exception: # noqa: S112
continue

View File

@@ -1410,9 +1410,7 @@ class Crew(FlowTrackable, BaseModel):
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return tools
def _add_memory_tools(
self, tools: list[BaseTool], memory: Any
) -> list[BaseTool]:
def _add_memory_tools(self, tools: list[BaseTool], memory: Any) -> list[BaseTool]:
"""Add recall and remember tools when memory is available.
Args:

View File

@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
max_workers = min(8, len(runnable_tool_calls))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_idx = {
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
pool.submit(
contextvars.copy_context().run,
self._execute_single_native_tool_call,
tool_call,
): idx
for idx, tool_call in enumerate(runnable_tool_calls)
}
ordered_results: list[dict[str, Any] | None] = [None] * len(

View File

@@ -34,6 +34,7 @@ class ConsoleProvider:
```python
from crewai.flow.async_feedback import ConsoleProvider
@human_feedback(
message="Review this:",
provider=ConsoleProvider(),
@@ -46,6 +47,7 @@ class ConsoleProvider:
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):

View File

@@ -497,7 +497,9 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._list)
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
def index(
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
) -> int: # type: ignore[override]
if stop is None:
return self._list.index(value, start)
return self._list.index(value, start, stop)

View File

@@ -188,7 +188,7 @@ def human_feedback(
metadata: dict[str, Any] | None = None,
provider: HumanFeedbackProvider | None = None,
learn: bool = False,
learn_source: str = "hitl"
learn_source: str = "hitl",
) -> Callable[[F], F]:
"""Decorator for Flow methods that require human feedback.
@@ -328,9 +328,7 @@ def human_feedback(
"""Recall past HITL lessons and use LLM to pre-review the output."""
try:
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
matches = flow_instance.memory.recall(
query, source=learn_source
)
matches = flow_instance.memory.recall(query, source=learn_source)
if not matches:
return method_output
@@ -341,7 +339,10 @@ def human_feedback(
lessons=lessons,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_pre_review_system"),
},
{"role": "user", "content": prompt},
]
if getattr(llm_inst, "supports_function_calling", lambda: False)():
@@ -366,7 +367,10 @@ def human_feedback(
feedback=raw_feedback,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_distill_system"),
},
{"role": "user", "content": prompt},
]
@@ -487,7 +491,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -507,7 +515,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -534,7 +546,7 @@ def human_feedback(
metadata=metadata,
provider=provider,
learn=learn,
learn_source=learn_source
learn_source=learn_source,
)
wrapper.__is_flow_method__ = True

View File

@@ -22,10 +22,10 @@ from crewai.mcp.config import (
MCPServerSSE,
MCPServerStdio,
)
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
from crewai.utilities.string_utils import sanitize_tool_name
if TYPE_CHECKING:
@@ -227,7 +227,9 @@ class MCPToolResolver:
server_params = {"url": server_url}
server_name = self._extract_server_name(server_url)
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
sanitized_specific_tool = (
sanitize_tool_name(specific_tool) if specific_tool else None
)
try:
tool_schemas = self._get_mcp_tool_schemas(server_params)

View File

@@ -308,7 +308,9 @@ def analyze_for_save(
return MemoryAnalysis.model_validate(response)
except Exception as e:
_logger.warning(
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
"Memory save analysis failed, using defaults: %s",
e,
exc_info=False,
)
return _SAVE_DEFAULTS
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
return ConsolidationPlan.model_validate(response)
except Exception as e:
_logger.warning(
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
"Consolidation analysis failed, defaulting to insert: %s",
e,
exc_info=False,
)
return _CONSOLIDATION_DEFAULT

View File

@@ -164,7 +164,11 @@ class EncodingFlow(Flow[EncodingState]):
def parallel_find_similar(self) -> None:
"""Search storage for similar records, concurrently for all active items."""
items = list(self.state.items)
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
active = [
(i, item)
for i, item in enumerate(items)
if not item.dropped and item.embedding
]
if not active:
return
@@ -186,7 +190,9 @@ class EncodingFlow(Flow[EncodingState]):
item.top_similarity = float(raw[0][1]) if raw else 0.0
else:
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
futures = [
(i, item, pool.submit(_search_one, item)) for i, item in active
]
for _, item, future in futures:
raw = future.result()
item.similar_records = [r for r, _ in raw]
@@ -251,23 +257,33 @@ class EncodingFlow(Flow[EncodingState]):
self._apply_defaults(item)
consol_futures[i] = pool.submit(
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
elif not fields_provided and not has_similar:
# Group C: field resolution only
save_futures[i] = pool.submit(
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
else:
# Group D: both in parallel
save_futures[i] = pool.submit(
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
consol_futures[i] = pool.submit(
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
# Collect field-resolution results
@@ -339,7 +355,9 @@ class EncodingFlow(Flow[EncodingState]):
# similar_records overlap). Collect one action per record_id, first wins.
# Also build a map from record_id to the original MemoryRecord for updates.
dedup_deletes: set[str] = set() # record_ids to delete
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
dedup_updates: dict[
str, tuple[int, str]
] = {} # record_id -> (item_idx, new_content)
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
for i, item in enumerate(items):
@@ -350,13 +368,24 @@ class EncodingFlow(Flow[EncodingState]):
all_similar[r.id] = r
for action in item.plan.actions:
rid = action.record_id
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
if (
action.action == "delete"
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_deletes.add(rid)
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
elif (
action.action == "update"
and action.new_content
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_updates[rid] = (i, action.new_content)
# --- Batch re-embed all update contents in ONE call ---
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
update_list = list(
dedup_updates.items()
) # [(record_id, (item_idx, new_content)), ...]
update_embeddings: list[list[float]] = []
if update_list:
update_contents = [content for _, (_, content) in update_list]
@@ -377,16 +406,21 @@ class EncodingFlow(Flow[EncodingState]):
if item.dropped or item.plan is None:
continue
if item.plan.insert_new:
to_insert.append((i, MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
)))
to_insert.append(
(
i,
MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
),
)
)
# All storage mutations under one lock so no other pipeline can
# interleave and cause version conflicts. The lock is reentrant

View File

@@ -103,13 +103,12 @@ class RecallFlow(Flow[RecallState]):
)
# Post-filter by time cutoff
if self.state.time_cutoff and raw:
raw = [
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
]
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
# Privacy filter
if not self.state.include_private and raw:
raw = [
(r, s) for r, s in raw
(r, s)
for r, s in raw
if not r.private or r.source == self.state.source
]
return scope, raw
@@ -130,16 +129,17 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
else:
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
futures = {
pool.submit(_search_one, emb, sc): (emb, sc)
for emb, sc in tasks
pool.submit(_search_one, emb, sc): (emb, sc) for emb, sc in tasks
}
for future in as_completed(futures):
scope, results = future.result()
@@ -147,16 +147,16 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
self.state.chunk_findings = findings
self.state.confidence = max(
(f["top_score"] for f in findings), default=0.0
)
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
return findings
# ------------------------------------------------------------------
@@ -210,12 +210,16 @@ class RecallFlow(Flow[RecallState]):
# Parse time_filter into a datetime cutoff
if analysis.time_filter:
try:
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
self.state.time_cutoff = datetime.fromisoformat(
analysis.time_filter
)
except ValueError:
pass
# Batch-embed all sub-queries in ONE call
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
queries = (
analysis.recall_queries if analysis.recall_queries else [self.state.query]
)
queries = queries[:3]
embeddings = embed_texts(self._embedder, queries)
pairs: list[tuple[str, list[float]]] = [
@@ -296,17 +300,21 @@ class RecallFlow(Flow[RecallState]):
response = self._llm.call([{"role": "user", "content": prompt}])
if isinstance(response, str) and "missing" in response.lower():
self.state.evidence_gaps.append(response[:200])
enhanced.append({
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
}
)
except Exception:
enhanced.append({
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
}
)
self.state.chunk_findings = enhanced
return enhanced

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import asyncio
import contextvars
from concurrent.futures import Future
import contextvars
from copy import copy as shallow_copy
import datetime
from hashlib import md5

View File

@@ -100,7 +100,12 @@ class I18N(BaseModel):
def retrieve(
self,
kind: Literal[
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
"slices",
"errors",
"tools",
"reasoning",
"hierarchical_manager_agent",
"memory",
],
key: str,
) -> str:

View File

@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
A tuple of (type, Field) for use with create_model.
"""
type_ = _json_schema_to_pydantic_type(
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
json_schema,
root_schema,
name_=name.title(),
enrich_descriptions=enrich_descriptions,
)
is_required = name in required
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
if ref:
ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type(
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
ref_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
enum_values = json_schema.get("enum")
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
if all_of_schemas:
if len(all_of_schemas) == 1:
return _json_schema_to_pydantic_type(
all_of_schemas[0], root_schema, name_=name_,
all_of_schemas[0],
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
return _json_schema_to_pydantic_type(
merged, root_schema, name_=name_,
merged,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
items_schema = json_schema.get("items")
if items_schema:
item_type = _json_schema_to_pydantic_type(
items_schema, root_schema, name_=name_,
items_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
return list[item_type] # type: ignore[valid-type]
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
if json_schema_.get("title") is None:
json_schema_["title"] = name_ or "DynamicModel"
return create_model_from_schema(
json_schema_, root_schema=root_schema,
json_schema_,
root_schema=root_schema,
enrich_descriptions=enrich_descriptions,
)
return dict