chore: linter on previous files

Greyson Lalonde
2026-03-05 19:56:12 -05:00
parent 218084d625
commit 4627c345c9
17 changed files with 244 additions and 131 deletions

View File

@@ -32,10 +32,7 @@ class CrewAgentExecutorMixin:
         )
         if memory is None or not self.task or getattr(memory, "_read_only", False):
             return
-        if (
-            f"Action: {sanitize_tool_name('Delegate work to coworker')}"
-            in output.text
-        ):
+        if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
             return
         try:
             raw = (
@@ -48,6 +45,4 @@ class CrewAgentExecutorMixin:
             if extracted:
                 memory.remember_many(extracted, agent_role=self.agent.role)
         except Exception as e:
-            self.agent._logger.log(
-                "error", f"Failed to save to memory: {e}"
-            )
+            self.agent._logger.log("error", f"Failed to save to memory: {e}")

View File

@@ -8,8 +8,8 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Callable
-import contextvars
 from concurrent.futures import ThreadPoolExecutor, as_completed
+import contextvars
 import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Literal, cast
@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 ToolUsageStartedEvent,
             )

-        args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
+        args_dict, parse_error = parse_tool_call_args(
+            func_args, func_name, call_id, original_tool
+        )
         if parse_error is not None:
             return parse_error

View File

@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
 @crewai.command()
 @click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
 @click.option(
-    "-l", "--long", is_flag=True, hidden=True,
+    "-l",
+    "--long",
+    is_flag=True,
+    hidden=True,
     help="[Deprecated: use --memory] Reset memory",
 )
 @click.option(
-    "-s", "--short", is_flag=True, hidden=True,
+    "-s",
+    "--short",
+    is_flag=True,
+    hidden=True,
     help="[Deprecated: use --memory] Reset memory",
 )
 @click.option(
-    "-e", "--entities", is_flag=True, hidden=True,
+    "-e",
+    "--entities",
+    is_flag=True,
+    hidden=True,
     help="[Deprecated: use --memory] Reset memory",
 )
 @click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@@ -218,7 +227,13 @@ def reset_memories(
         # Treat legacy flags as --memory with a deprecation warning
         if long or short or entities:
             legacy_used = [
-                f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
+                f
+                for f, v in [
+                    ("--long", long),
+                    ("--short", short),
+                    ("--entities", entities),
+                ]
+                if v
             ]
             click.echo(
                 f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
@@ -238,9 +253,7 @@ def reset_memories(
                 "Please specify at least one memory type to reset using the appropriate flags."
             )
             return

-        reset_memories_command(
-            memory, knowledge, agent_knowledge, kickoff_outputs, all
-        )
+        reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
     except Exception as e:
         click.echo(f"An error occurred while resetting memories: {e}", err=True)

View File

@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
             from crewai.memory.storage.lancedb_storage import LanceDBStorage
             from crewai.memory.unified_memory import Memory

-            storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
+            storage = (
+                LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
+            )
             embedder = None
             if embedder_config is not None:
                 from crewai.rag.embeddings.factory import build_embedder

                 embedder = build_embedder(embedder_config)
-            self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
+            self._memory = (
+                Memory(storage=storage, embedder=embedder)
+                if embedder
+                else Memory(storage=storage)
+            )
         except Exception as e:
             self._init_error = str(e)
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
                 if len(record.content) > 80
                 else record.content
             )
-            label = (
-                f"{date_str} "
-                f"[bold]{record.importance:.1f}[/] "
-                f"{preview}"
-            )
+            label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
             option_list.add_option(label)

     def _populate_recall_list(self) -> None:
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
                 else m.record.content
             )
             label = (
-                f"[bold]\\[{m.score:.2f}][/] "
-                f"{preview} "
-                f"[dim]scope={m.record.scope}[/]"
+                f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
             )
             option_list.add_option(label)
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
         lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
         lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
         lines.append(
-            f"[dim]Created:[/] "
-            f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
+            f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
         )
         lines.append(
             f"[dim]Last accessed:[/] "
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
         panel = self.query_one("#info-panel", Static)
         panel.loading = True
         try:
-            scope = (
-                self._selected_scope
-                if self._selected_scope != "/"
-                else None
-            )
+            scope = self._selected_scope if self._selected_scope != "/" else None
             loop = asyncio.get_event_loop()
             matches = await loop.run_in_executor(
                 None,
-                lambda: self._memory.recall(
-                    query, scope=scope, limit=10, depth="deep"
-                ),
+                lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
             )
             self._recall_matches = matches or []
             self._view_mode = "recall"

View File

@@ -95,9 +95,7 @@ def reset_memories_command(
                     continue
                 if memory:
                     _reset_flow_memory(flow)
-                    click.echo(
-                        f"[Flow ({flow_name})] Memory has been reset."
-                    )
+                    click.echo(f"[Flow ({flow_name})] Memory has been reset.")
     except subprocess.CalledProcessError as e:
         click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
     for search_path in search_paths:
         for root, dirs, files in os.walk(search_path):
             dirs[:] = [
-                d
-                for d in dirs
-                if d not in _SKIP_DIRS and not d.startswith(".")
+                d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
             ]
             if flow_path in files and "cli/templates" not in root:
                 file_os_path = os.path.join(root, flow_path)
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
         for attr_name in dir(module):
             module_attr = getattr(module, attr_name)
             try:
-                if flow_instance := get_flow_instance(
-                    module_attr
-                ):
+                if flow_instance := get_flow_instance(module_attr):
                     flow_instances.append(flow_instance)
             except Exception:  # noqa: S112
                 continue

View File

@@ -1,9 +1,9 @@
 from __future__ import annotations

 import asyncio
-import contextvars
 from collections.abc import Callable, Coroutine
 from concurrent.futures import ThreadPoolExecutor, as_completed
+import contextvars
 from datetime import datetime
 import inspect
 import json
@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
         max_workers = min(8, len(runnable_tool_calls))
         with ThreadPoolExecutor(max_workers=max_workers) as pool:
             future_to_idx = {
-                pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
+                pool.submit(
+                    contextvars.copy_context().run,
+                    self._execute_single_native_tool_call,
+                    tool_call,
+                ): idx
                 for idx, tool_call in enumerate(runnable_tool_calls)
             }
             ordered_results: list[dict[str, Any] | None] = [None] * len(
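The reformatted dict comprehension above relies on `contextvars.copy_context().run` so each pooled tool call sees the submitting thread's context variables. A minimal illustrative sketch of that mechanism (stdlib only; the names below are invented for the example, not taken from this commit):

```python
# Illustrative: propagate a ContextVar into ThreadPoolExecutor workers by
# submitting work through contextvars.copy_context().run.
import contextvars
from concurrent.futures import ThreadPoolExecutor

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-")


def handle(item: int) -> str:
    # Sees the value captured at submit time, not the worker thread's default.
    return f"{request_id.get()}:{item}"


request_id.set("req-42")
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = {
        pool.submit(contextvars.copy_context().run, handle, item): item
        for item in range(3)
    }
    results = [f.result() for f in futures]

print(results)  # ['req-42:0', 'req-42:1', 'req-42:2']
```

Without the copied context, each worker thread would see only the ContextVar's default value.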

View File

@@ -34,6 +34,7 @@ class ConsoleProvider:
         ```python
         from crewai.flow.async_feedback import ConsoleProvider

+
         @human_feedback(
             message="Review this:",
             provider=ConsoleProvider(),
@@ -46,6 +47,7 @@ class ConsoleProvider:
         ```python
         from crewai.flow import Flow, start

+
         class MyFlow(Flow):
             @start()
             def gather_info(self):

View File

@@ -188,7 +188,7 @@ def human_feedback(
     metadata: dict[str, Any] | None = None,
     provider: HumanFeedbackProvider | None = None,
     learn: bool = False,
-    learn_source: str = "hitl"
+    learn_source: str = "hitl",
 ) -> Callable[[F], F]:
     """Decorator for Flow methods that require human feedback.
@@ -328,9 +328,7 @@ def human_feedback(
        """Recall past HITL lessons and use LLM to pre-review the output."""
        try:
            query = f"human feedback lessons for {func.__name__}: {method_output!s}"
-            matches = flow_instance.memory.recall(
-                query, source=learn_source
-            )
+            matches = flow_instance.memory.recall(query, source=learn_source)
            if not matches:
                return method_output
@@ -341,7 +339,10 @@ def human_feedback(
                lessons=lessons,
            )
            messages = [
-                {"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
+                {
+                    "role": "system",
+                    "content": _get_hitl_prompt("hitl_pre_review_system"),
+                },
                {"role": "user", "content": prompt},
            ]
            if getattr(llm_inst, "supports_function_calling", lambda: False)():
@@ -366,7 +367,10 @@ def human_feedback(
                feedback=raw_feedback,
            )
            messages = [
-                {"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
+                {
+                    "role": "system",
+                    "content": _get_hitl_prompt("hitl_distill_system"),
+                },
                {"role": "user", "content": prompt},
            ]
@@ -487,7 +491,11 @@ def human_feedback(
            result = _process_feedback(self, method_output, raw_feedback)

            # Distill: extract lessons from output + feedback, store in memory
-            if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
+            if (
+                learn
+                and getattr(self, "memory", None) is not None
+                and raw_feedback.strip()
+            ):
                _distill_and_store_lessons(self, method_output, raw_feedback)

            return result
@@ -507,7 +515,11 @@ def human_feedback(
            result = _process_feedback(self, method_output, raw_feedback)

            # Distill: extract lessons from output + feedback, store in memory
-            if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
+            if (
+                learn
+                and getattr(self, "memory", None) is not None
+                and raw_feedback.strip()
+            ):
                _distill_and_store_lessons(self, method_output, raw_feedback)

            return result
@@ -534,7 +546,7 @@ def human_feedback(
            metadata=metadata,
            provider=provider,
            learn=learn,
-            learn_source=learn_source
+            learn_source=learn_source,
        )
        wrapper.__is_flow_method__ = True

View File

@@ -308,7 +308,9 @@ def analyze_for_save(
         return MemoryAnalysis.model_validate(response)
     except Exception as e:
         _logger.warning(
-            "Memory save analysis failed, using defaults: %s", e, exc_info=False,
+            "Memory save analysis failed, using defaults: %s",
+            e,
+            exc_info=False,
         )
         return _SAVE_DEFAULTS
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
         return ConsolidationPlan.model_validate(response)
     except Exception as e:
         _logger.warning(
-            "Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
+            "Consolidation analysis failed, defaulting to insert: %s",
+            e,
+            exc_info=False,
         )
         return _CONSOLIDATION_DEFAULT

View File

@@ -164,7 +164,11 @@ class EncodingFlow(Flow[EncodingState]):
     def parallel_find_similar(self) -> None:
         """Search storage for similar records, concurrently for all active items."""
         items = list(self.state.items)
-        active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
+        active = [
+            (i, item)
+            for i, item in enumerate(items)
+            if not item.dropped and item.embedding
+        ]
         if not active:
             return
@@ -186,7 +190,9 @@ class EncodingFlow(Flow[EncodingState]):
                 item.top_similarity = float(raw[0][1]) if raw else 0.0
         else:
             with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
-                futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
+                futures = [
+                    (i, item, pool.submit(_search_one, item)) for i, item in active
+                ]
                 for _, item, future in futures:
                     raw = future.result()
                     item.similar_records = [r for r, _ in raw]
@@ -251,23 +257,33 @@ class EncodingFlow(Flow[EncodingState]):
                     self._apply_defaults(item)
                     consol_futures[i] = pool.submit(
                         analyze_for_consolidation,
-                        item.content, list(item.similar_records), self._llm,
+                        item.content,
+                        list(item.similar_records),
+                        self._llm,
                     )
                 elif not fields_provided and not has_similar:
                     # Group C: field resolution only
                     save_futures[i] = pool.submit(
                         analyze_for_save,
-                        item.content, existing_scopes, existing_categories, self._llm,
+                        item.content,
+                        existing_scopes,
+                        existing_categories,
+                        self._llm,
                     )
                 else:
                     # Group D: both in parallel
                     save_futures[i] = pool.submit(
                         analyze_for_save,
-                        item.content, existing_scopes, existing_categories, self._llm,
+                        item.content,
+                        existing_scopes,
+                        existing_categories,
+                        self._llm,
                     )
                     consol_futures[i] = pool.submit(
                         analyze_for_consolidation,
-                        item.content, list(item.similar_records), self._llm,
+                        item.content,
+                        list(item.similar_records),
+                        self._llm,
                     )

         # Collect field-resolution results
@@ -339,7 +355,9 @@ class EncodingFlow(Flow[EncodingState]):
         # similar_records overlap). Collect one action per record_id, first wins.
         # Also build a map from record_id to the original MemoryRecord for updates.
         dedup_deletes: set[str] = set()  # record_ids to delete
-        dedup_updates: dict[str, tuple[int, str]] = {}  # record_id -> (item_idx, new_content)
+        dedup_updates: dict[
+            str, tuple[int, str]
+        ] = {}  # record_id -> (item_idx, new_content)
         all_similar: dict[str, MemoryRecord] = {}  # record_id -> MemoryRecord

         for i, item in enumerate(items):
@@ -350,13 +368,24 @@ class EncodingFlow(Flow[EncodingState]):
                 all_similar[r.id] = r
             for action in item.plan.actions:
                 rid = action.record_id
-                if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
+                if (
+                    action.action == "delete"
+                    and rid not in dedup_deletes
+                    and rid not in dedup_updates
+                ):
                     dedup_deletes.add(rid)
-                elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
+                elif (
+                    action.action == "update"
+                    and action.new_content
+                    and rid not in dedup_deletes
+                    and rid not in dedup_updates
+                ):
                     dedup_updates[rid] = (i, action.new_content)

         # --- Batch re-embed all update contents in ONE call ---
-        update_list = list(dedup_updates.items())  # [(record_id, (item_idx, new_content)), ...]
+        update_list = list(
+            dedup_updates.items()
+        )  # [(record_id, (item_idx, new_content)), ...]
         update_embeddings: list[list[float]] = []
         if update_list:
             update_contents = [content for _, (_, content) in update_list]
@@ -377,16 +406,21 @@ class EncodingFlow(Flow[EncodingState]):
             if item.dropped or item.plan is None:
                 continue
             if item.plan.insert_new:
-                to_insert.append((i, MemoryRecord(
-                    content=item.content,
-                    scope=item.resolved_scope,
-                    categories=item.resolved_categories,
-                    metadata=item.resolved_metadata,
-                    importance=item.resolved_importance,
-                    embedding=item.embedding if item.embedding else None,
-                    source=item.resolved_source,
-                    private=item.resolved_private,
-                )))
+                to_insert.append(
+                    (
+                        i,
+                        MemoryRecord(
+                            content=item.content,
+                            scope=item.resolved_scope,
+                            categories=item.resolved_categories,
+                            metadata=item.resolved_metadata,
+                            importance=item.resolved_importance,
+                            embedding=item.embedding if item.embedding else None,
+                            source=item.resolved_source,
+                            private=item.resolved_private,
+                        ),
+                    )
+                )

         # All storage mutations under one lock so no other pipeline can
         # interleave and cause version conflicts. The lock is reentrant
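The dedup block reformatted above is a first-wins policy: once a record_id has been claimed for delete or update by any item's plan, later actions against it are ignored. A simplified, self-contained sketch of the same policy (plain tuples stand in for the real action objects, which carry more fields):

```python
# Illustrative "first wins" dedup over consolidation actions: a record_id that
# has already been claimed for delete or update is never claimed again.
actions = [
    ("delete", "rec-1", None),
    ("update", "rec-1", "newer text"),   # ignored: rec-1 already claimed
    ("update", "rec-2", "merged text"),
    ("delete", "rec-2", None),           # ignored: rec-2 already claimed
]

dedup_deletes: set[str] = set()
dedup_updates: dict[str, str] = {}  # record_id -> new_content

for kind, rid, new_content in actions:
    if kind == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
        dedup_deletes.add(rid)
    elif (
        kind == "update"
        and new_content
        and rid not in dedup_deletes
        and rid not in dedup_updates
    ):
        dedup_updates[rid] = new_content

print(dedup_deletes)  # {'rec-1'}
print(dedup_updates)  # {'rec-2': 'merged text'}
```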

View File

@@ -249,9 +249,17 @@ class MemorySlice:
             total_records += inf.record_count
             all_categories.update(inf.categories)
             if inf.oldest_record:
-                oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
+                oldest = (
+                    inf.oldest_record
+                    if oldest is None
+                    else min(oldest, inf.oldest_record)
+                )
             if inf.newest_record:
-                newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
+                newest = (
+                    inf.newest_record
+                    if newest is None
+                    else max(newest, inf.newest_record)
+                )
             children.extend(inf.child_scopes)
         return ScopeInfo(
             path=path,

View File

@@ -103,13 +103,12 @@ class RecallFlow(Flow[RecallState]):
             )
             # Post-filter by time cutoff
             if self.state.time_cutoff and raw:
-                raw = [
-                    (r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
-                ]
+                raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
             # Privacy filter
             if not self.state.include_private and raw:
                 raw = [
-                    (r, s) for r, s in raw
+                    (r, s)
+                    for r, s in raw
                     if not r.private or r.source == self.state.source
                 ]
             return scope, raw
@@ -130,16 +129,17 @@ class RecallFlow(Flow[RecallState]):
                     top_composite, _ = compute_composite_score(
                         results[0][0], results[0][1], self._config
                     )
-                    findings.append({
-                        "scope": scope,
-                        "results": results,
-                        "top_score": top_composite,
-                    })
+                    findings.append(
+                        {
+                            "scope": scope,
+                            "results": results,
+                            "top_score": top_composite,
+                        }
+                    )
         else:
             with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
                 futures = {
-                    pool.submit(_search_one, emb, sc): (emb, sc)
-                    for emb, sc in tasks
+                    pool.submit(_search_one, emb, sc): (emb, sc) for emb, sc in tasks
                 }
                 for future in as_completed(futures):
                     scope, results = future.result()
@@ -147,16 +147,16 @@ class RecallFlow(Flow[RecallState]):
                         top_composite, _ = compute_composite_score(
                             results[0][0], results[0][1], self._config
                         )
-                        findings.append({
-                            "scope": scope,
-                            "results": results,
-                            "top_score": top_composite,
-                        })
+                        findings.append(
+                            {
+                                "scope": scope,
+                                "results": results,
+                                "top_score": top_composite,
+                            }
+                        )

         self.state.chunk_findings = findings
-        self.state.confidence = max(
-            (f["top_score"] for f in findings), default=0.0
-        )
+        self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
         return findings

     # ------------------------------------------------------------------
@@ -210,12 +210,16 @@ class RecallFlow(Flow[RecallState]):
         # Parse time_filter into a datetime cutoff
         if analysis.time_filter:
             try:
-                self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
+                self.state.time_cutoff = datetime.fromisoformat(
+                    analysis.time_filter
+                )
             except ValueError:
                 pass

         # Batch-embed all sub-queries in ONE call
-        queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
+        queries = (
+            analysis.recall_queries if analysis.recall_queries else [self.state.query]
+        )
         queries = queries[:3]
         embeddings = embed_texts(self._embedder, queries)
         pairs: list[tuple[str, list[float]]] = [
@@ -296,17 +300,21 @@ class RecallFlow(Flow[RecallState]):
                 response = self._llm.call([{"role": "user", "content": prompt}])
                 if isinstance(response, str) and "missing" in response.lower():
                     self.state.evidence_gaps.append(response[:200])
-                enhanced.append({
-                    "scope": finding["scope"],
-                    "extraction": response,
-                    "results": finding["results"],
-                })
+                enhanced.append(
+                    {
+                        "scope": finding["scope"],
+                        "extraction": response,
+                        "results": finding["results"],
+                    }
+                )
             except Exception:
-                enhanced.append({
-                    "scope": finding["scope"],
-                    "extraction": "",
-                    "results": finding["results"],
-                })
+                enhanced.append(
+                    {
+                        "scope": finding["scope"],
+                        "extraction": "",
+                        "results": finding["results"],
+                    }
+                )

         self.state.chunk_findings = enhanced
         return enhanced
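Both branches of the scope-search hunk end the same way: per-scope results become findings, and the overall confidence is the best top score, defaulting to 0.0 when nothing matched. A compact illustrative sketch of that fan-out/fan-in shape (the search function here is a stub, not the project's storage call):

```python
# Illustrative fan-out/fan-in: search each scope in a worker thread, then take
# the best per-scope score as the overall confidence.
from concurrent.futures import ThreadPoolExecutor, as_completed


def search_one(scope: str) -> tuple[str, list[tuple[str, float]]]:
    # Stand-in for a vector search; returns (scope, [(record, score), ...]).
    fake = {"/projects": [("note-a", 0.82)], "/people": [("note-b", 0.55)], "/misc": []}
    return scope, fake.get(scope, [])


scopes = ["/projects", "/people", "/misc"]
findings = []
with ThreadPoolExecutor(max_workers=min(len(scopes), 4)) as pool:
    futures = {pool.submit(search_one, sc): sc for sc in scopes}
    for future in as_completed(futures):
        scope, results = future.result()
        if results:
            findings.append(
                {"scope": scope, "results": results, "top_score": results[0][1]}
            )

confidence = max((f["top_score"] for f in findings), default=0.0)
print(confidence)  # 0.82
```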

View File

@@ -90,6 +90,7 @@ class LanceDBStorage:
         # Raise it proactively so scans on large tables never hit OS error 24.
         try:
             import resource

+
             soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
             if soft < 4096:
                 resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
@@ -110,7 +111,9 @@ class LanceDBStorage:
         # If no table exists yet, defer creation until the first save so the
         # dimension can be auto-detected from the embedder's actual output.
         try:
-            self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
+            self._table: lancedb.table.Table | None = self._db.open_table(
+                self._table_name
+            )
             self._vector_dim: int = self._infer_dim_from_table(self._table)
             # Best-effort: create the scope index if it doesn't exist yet.
             self._ensure_scope_index()
@@ -171,7 +174,10 @@ class LanceDBStorage:
                     raise
                 _logger.debug(
                     "LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
-                    op, attempt + 1, _MAX_RETRIES, delay,
+                    op,
+                    attempt + 1,
+                    _MAX_RETRIES,
+                    delay,
                 )
                 # Refresh table to pick up the latest version before retrying.
                 # The next getattr(self._table, op) will use the fresh table.
@@ -280,7 +286,9 @@ class LanceDBStorage:
             "last_accessed": record.last_accessed.isoformat(),
             "source": record.source or "",
             "private": record.private,
-            "vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
+            "vector": record.embedding
+            if record.embedding
+            else [0.0] * self._vector_dim,
         }

     def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
@@ -296,7 +304,9 @@ class LanceDBStorage:
             id=str(row["id"]),
             content=str(row["content"]),
             scope=str(row["scope"]),
-            categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
+            categories=json.loads(row["categories_str"])
+            if row.get("categories_str")
+            else [],
             metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
             importance=float(row.get("importance", 0.5)),
             created_at=_parse_dt(row.get("created_at")),
@@ -390,13 +400,17 @@ class LanceDBStorage:
             prefix = scope_prefix.rstrip("/")
             like_val = prefix + "%"
             query = query.where(f"scope LIKE '{like_val}'")
-        results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
+        results = query.limit(
+            limit * 3 if (categories or metadata_filter) else limit
+        ).to_list()
         out: list[tuple[MemoryRecord, float]] = []
         for row in results:
             record = self._row_to_record(row)
             if categories and not any(c in record.categories for c in categories):
                 continue
-            if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
+            if metadata_filter and not all(
+                record.metadata.get(k) == v for k, v in metadata_filter.items()
+            ):
                 continue
             distance = row.get("_distance", 0.0)
             score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
@@ -427,9 +441,13 @@ class LanceDBStorage:
             to_delete: list[str] = []
             for row in rows:
                 record = self._row_to_record(row)
-                if categories and not any(c in record.categories for c in categories):
+                if categories and not any(
+                    c in record.categories for c in categories
+                ):
                     continue
-                if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
+                if metadata_filter and not all(
+                    record.metadata.get(k) == v for k, v in metadata_filter.items()
+                ):
                     continue
                 if older_than and record.created_at >= older_than:
                     continue
@@ -528,7 +546,7 @@ class LanceDBStorage:
             for row in rows:
                 sc = str(row.get("scope", ""))
                 if child_prefix and sc.startswith(child_prefix):
-                    rest = sc[len(child_prefix):]
+                    rest = sc[len(child_prefix) :]
                     first_component = rest.split("/", 1)[0]
                     if first_component:
                         children.add(child_prefix + first_component)
@@ -539,7 +557,11 @@ class LanceDBStorage:
                         pass
                 created = row.get("created_at")
                 if created:
-                    dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
+                    dt = (
+                        datetime.fromisoformat(str(created).replace("Z", "+00:00"))
+                        if isinstance(created, str)
+                        else created
+                    )
                     if isinstance(dt, datetime):
                         if oldest is None or dt < oldest:
                             oldest = dt
@@ -562,7 +584,7 @@ class LanceDBStorage:
             for row in rows:
                 sc = str(row.get("scope", ""))
                 if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
-                    rest = sc[len(prefix):]
+                    rest = sc[len(prefix) :]
                     first_component = rest.split("/", 1)[0]
                     if first_component:
                         children.add(prefix + first_component)
@@ -600,7 +622,7 @@ class LanceDBStorage:
             return
         prefix = scope_prefix.rstrip("/")
         if prefix:
-            self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
+            self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")

     def optimize(self) -> None:
         """Compact the table synchronously and refresh the scope index.

View File

@@ -150,7 +150,11 @@ class Memory:
         if isinstance(storage, str):
             from crewai.memory.storage.lancedb_storage import LanceDBStorage

-            self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
+            self._storage = (
+                LanceDBStorage()
+                if storage == "lancedb"
+                else LanceDBStorage(path=storage)
+            )
         else:
             self._storage = storage

View File

@@ -100,7 +100,12 @@ class I18N(BaseModel):
     def retrieve(
         self,
         kind: Literal[
-            "slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
+            "slices",
+            "errors",
+            "tools",
+            "reasoning",
+            "hierarchical_manager_agent",
+            "memory",
         ],
         key: str,
     ) -> str:

View File

@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
         A tuple of (type, Field) for use with create_model.
     """
     type_ = _json_schema_to_pydantic_type(
-        json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
+        json_schema,
+        root_schema,
+        name_=name.title(),
+        enrich_descriptions=enrich_descriptions,
     )
     is_required = name in required
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
     if ref:
         ref_schema = _resolve_ref(ref, root_schema)
         return _json_schema_to_pydantic_type(
-            ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
+            ref_schema,
+            root_schema,
+            name_=name_,
+            enrich_descriptions=enrich_descriptions,
         )

     enum_values = json_schema.get("enum")
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
     if all_of_schemas:
         if len(all_of_schemas) == 1:
             return _json_schema_to_pydantic_type(
-                all_of_schemas[0], root_schema, name_=name_,
+                all_of_schemas[0],
+                root_schema,
+                name_=name_,
                 enrich_descriptions=enrich_descriptions,
             )
         merged = _merge_all_of_schemas(all_of_schemas, root_schema)
         return _json_schema_to_pydantic_type(
-            merged, root_schema, name_=name_,
+            merged,
+            root_schema,
+            name_=name_,
             enrich_descriptions=enrich_descriptions,
         )
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
     items_schema = json_schema.get("items")
     if items_schema:
         item_type = _json_schema_to_pydantic_type(
-            items_schema, root_schema, name_=name_,
+            items_schema,
+            root_schema,
+            name_=name_,
             enrich_descriptions=enrich_descriptions,
         )
         return list[item_type]  # type: ignore[valid-type]
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
     if json_schema_.get("title") is None:
         json_schema_["title"] = name_ or "DynamicModel"
     return create_model_from_schema(
-        json_schema_, root_schema=root_schema,
+        json_schema_,
+        root_schema=root_schema,
         enrich_descriptions=enrich_descriptions,
     )
     return dict