chore: linter on previous files

This commit is contained in:
Greyson Lalonde
2026-03-05 19:56:12 -05:00
parent 218084d625
commit 4627c345c9
17 changed files with 244 additions and 131 deletions

View File

@@ -32,10 +32,7 @@ class CrewAgentExecutorMixin:
)
if memory is None or not self.task or getattr(memory, "_read_only", False):
return
if (
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
in output.text
):
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
return
try:
raw = (
@@ -48,6 +45,4 @@ class CrewAgentExecutorMixin:
if extracted:
memory.remember_many(extracted, agent_role=self.agent.role)
except Exception as e:
self.agent._logger.log(
"error", f"Failed to save to memory: {e}"
)
self.agent._logger.log("error", f"Failed to save to memory: {e}")

View File

@@ -8,8 +8,8 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
ToolUsageStartedEvent,
)
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
args_dict, parse_error = parse_tool_call_args(
func_args, func_name, call_id, original_tool
)
if parse_error is not None:
return parse_error

View File

@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
@crewai.command()
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
@click.option(
"-l", "--long", is_flag=True, hidden=True,
"-l",
"--long",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-s", "--short", is_flag=True, hidden=True,
"-s",
"--short",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option(
"-e", "--entities", is_flag=True, hidden=True,
"-e",
"--entities",
is_flag=True,
hidden=True,
help="[Deprecated: use --memory] Reset memory",
)
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@@ -218,7 +227,13 @@ def reset_memories(
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
f
for f, v in [
("--long", long),
("--short", short),
("--entities", entities),
]
if v
]
click.echo(
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
@@ -238,9 +253,7 @@ def reset_memories(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(
memory, knowledge, agent_knowledge, kickoff_outputs, all
)
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)

View File

@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.unified_memory import Memory
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
storage = (
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
)
embedder = None
if embedder_config is not None:
from crewai.rag.embeddings.factory import build_embedder
embedder = build_embedder(embedder_config)
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
self._memory = (
Memory(storage=storage, embedder=embedder)
if embedder
else Memory(storage=storage)
)
except Exception as e:
self._init_error = str(e)
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
if len(record.content) > 80
else record.content
)
label = (
f"{date_str} "
f"[bold]{record.importance:.1f}[/] "
f"{preview}"
)
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
option_list.add_option(label)
def _populate_recall_list(self) -> None:
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
else m.record.content
)
label = (
f"[bold]\\[{m.score:.2f}][/] "
f"{preview} "
f"[dim]scope={m.record.scope}[/]"
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
)
option_list.add_option(label)
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
lines.append(
f"[dim]Created:[/] "
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
)
lines.append(
f"[dim]Last accessed:[/] "
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
panel = self.query_one("#info-panel", Static)
panel.loading = True
try:
scope = (
self._selected_scope
if self._selected_scope != "/"
else None
)
scope = self._selected_scope if self._selected_scope != "/" else None
loop = asyncio.get_event_loop()
matches = await loop.run_in_executor(
None,
lambda: self._memory.recall(
query, scope=scope, limit=10, depth="deep"
),
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
)
self._recall_matches = matches or []
self._view_mode = "recall"

View File

@@ -95,9 +95,7 @@ def reset_memories_command(
continue
if memory:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Memory has been reset."
)
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for search_path in search_paths:
for root, dirs, files in os.walk(search_path):
dirs[:] = [
d
for d in dirs
if d not in _SKIP_DIRS and not d.startswith(".")
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
]
if flow_path in files and "cli/templates" not in root:
file_os_path = os.path.join(root, flow_path)
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
try:
if flow_instance := get_flow_instance(
module_attr
):
if flow_instance := get_flow_instance(module_attr):
flow_instances.append(flow_instance)
except Exception: # noqa: S112
continue

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
import asyncio
import contextvars
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from datetime import datetime
import inspect
import json
@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
max_workers = min(8, len(runnable_tool_calls))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_idx = {
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
pool.submit(
contextvars.copy_context().run,
self._execute_single_native_tool_call,
tool_call,
): idx
for idx, tool_call in enumerate(runnable_tool_calls)
}
ordered_results: list[dict[str, Any] | None] = [None] * len(

View File

@@ -34,6 +34,7 @@ class ConsoleProvider:
```python
from crewai.flow.async_feedback import ConsoleProvider
@human_feedback(
message="Review this:",
provider=ConsoleProvider(),
@@ -46,6 +47,7 @@ class ConsoleProvider:
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):

View File

@@ -188,7 +188,7 @@ def human_feedback(
metadata: dict[str, Any] | None = None,
provider: HumanFeedbackProvider | None = None,
learn: bool = False,
learn_source: str = "hitl"
learn_source: str = "hitl",
) -> Callable[[F], F]:
"""Decorator for Flow methods that require human feedback.
@@ -328,9 +328,7 @@ def human_feedback(
"""Recall past HITL lessons and use LLM to pre-review the output."""
try:
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
matches = flow_instance.memory.recall(
query, source=learn_source
)
matches = flow_instance.memory.recall(query, source=learn_source)
if not matches:
return method_output
@@ -341,7 +339,10 @@ def human_feedback(
lessons=lessons,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_pre_review_system"),
},
{"role": "user", "content": prompt},
]
if getattr(llm_inst, "supports_function_calling", lambda: False)():
@@ -366,7 +367,10 @@ def human_feedback(
feedback=raw_feedback,
)
messages = [
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
{
"role": "system",
"content": _get_hitl_prompt("hitl_distill_system"),
},
{"role": "user", "content": prompt},
]
@@ -487,7 +491,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -507,7 +515,11 @@ def human_feedback(
result = _process_feedback(self, method_output, raw_feedback)
# Distill: extract lessons from output + feedback, store in memory
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
if (
learn
and getattr(self, "memory", None) is not None
and raw_feedback.strip()
):
_distill_and_store_lessons(self, method_output, raw_feedback)
return result
@@ -534,7 +546,7 @@ def human_feedback(
metadata=metadata,
provider=provider,
learn=learn,
learn_source=learn_source
learn_source=learn_source,
)
wrapper.__is_flow_method__ = True

View File

@@ -308,7 +308,9 @@ def analyze_for_save(
return MemoryAnalysis.model_validate(response)
except Exception as e:
_logger.warning(
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
"Memory save analysis failed, using defaults: %s",
e,
exc_info=False,
)
return _SAVE_DEFAULTS
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
return ConsolidationPlan.model_validate(response)
except Exception as e:
_logger.warning(
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
"Consolidation analysis failed, defaulting to insert: %s",
e,
exc_info=False,
)
return _CONSOLIDATION_DEFAULT

View File

@@ -164,7 +164,11 @@ class EncodingFlow(Flow[EncodingState]):
def parallel_find_similar(self) -> None:
"""Search storage for similar records, concurrently for all active items."""
items = list(self.state.items)
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
active = [
(i, item)
for i, item in enumerate(items)
if not item.dropped and item.embedding
]
if not active:
return
@@ -186,7 +190,9 @@ class EncodingFlow(Flow[EncodingState]):
item.top_similarity = float(raw[0][1]) if raw else 0.0
else:
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
futures = [
(i, item, pool.submit(_search_one, item)) for i, item in active
]
for _, item, future in futures:
raw = future.result()
item.similar_records = [r for r, _ in raw]
@@ -251,23 +257,33 @@ class EncodingFlow(Flow[EncodingState]):
self._apply_defaults(item)
consol_futures[i] = pool.submit(
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
elif not fields_provided and not has_similar:
# Group C: field resolution only
save_futures[i] = pool.submit(
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
else:
# Group D: both in parallel
save_futures[i] = pool.submit(
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
consol_futures[i] = pool.submit(
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
# Collect field-resolution results
@@ -339,7 +355,9 @@ class EncodingFlow(Flow[EncodingState]):
# similar_records overlap). Collect one action per record_id, first wins.
# Also build a map from record_id to the original MemoryRecord for updates.
dedup_deletes: set[str] = set() # record_ids to delete
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
dedup_updates: dict[
str, tuple[int, str]
] = {} # record_id -> (item_idx, new_content)
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
for i, item in enumerate(items):
@@ -350,13 +368,24 @@ class EncodingFlow(Flow[EncodingState]):
all_similar[r.id] = r
for action in item.plan.actions:
rid = action.record_id
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
if (
action.action == "delete"
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_deletes.add(rid)
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
elif (
action.action == "update"
and action.new_content
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_updates[rid] = (i, action.new_content)
# --- Batch re-embed all update contents in ONE call ---
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
update_list = list(
dedup_updates.items()
) # [(record_id, (item_idx, new_content)), ...]
update_embeddings: list[list[float]] = []
if update_list:
update_contents = [content for _, (_, content) in update_list]
@@ -377,16 +406,21 @@ class EncodingFlow(Flow[EncodingState]):
if item.dropped or item.plan is None:
continue
if item.plan.insert_new:
to_insert.append((i, MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
)))
to_insert.append(
(
i,
MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
),
)
)
# All storage mutations under one lock so no other pipeline can
# interleave and cause version conflicts. The lock is reentrant

View File

@@ -249,9 +249,17 @@ class MemorySlice:
total_records += inf.record_count
all_categories.update(inf.categories)
if inf.oldest_record:
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
oldest = (
inf.oldest_record
if oldest is None
else min(oldest, inf.oldest_record)
)
if inf.newest_record:
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
newest = (
inf.newest_record
if newest is None
else max(newest, inf.newest_record)
)
children.extend(inf.child_scopes)
return ScopeInfo(
path=path,

View File

@@ -103,13 +103,12 @@ class RecallFlow(Flow[RecallState]):
)
# Post-filter by time cutoff
if self.state.time_cutoff and raw:
raw = [
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
]
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
# Privacy filter
if not self.state.include_private and raw:
raw = [
(r, s) for r, s in raw
(r, s)
for r, s in raw
if not r.private or r.source == self.state.source
]
return scope, raw
@@ -130,16 +129,17 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
else:
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
futures = {
pool.submit(_search_one, emb, sc): (emb, sc)
for emb, sc in tasks
pool.submit(_search_one, emb, sc): (emb, sc) for emb, sc in tasks
}
for future in as_completed(futures):
scope, results = future.result()
@@ -147,16 +147,16 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
self.state.chunk_findings = findings
self.state.confidence = max(
(f["top_score"] for f in findings), default=0.0
)
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
return findings
# ------------------------------------------------------------------
@@ -210,12 +210,16 @@ class RecallFlow(Flow[RecallState]):
# Parse time_filter into a datetime cutoff
if analysis.time_filter:
try:
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
self.state.time_cutoff = datetime.fromisoformat(
analysis.time_filter
)
except ValueError:
pass
# Batch-embed all sub-queries in ONE call
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
queries = (
analysis.recall_queries if analysis.recall_queries else [self.state.query]
)
queries = queries[:3]
embeddings = embed_texts(self._embedder, queries)
pairs: list[tuple[str, list[float]]] = [
@@ -296,17 +300,21 @@ class RecallFlow(Flow[RecallState]):
response = self._llm.call([{"role": "user", "content": prompt}])
if isinstance(response, str) and "missing" in response.lower():
self.state.evidence_gaps.append(response[:200])
enhanced.append({
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
}
)
except Exception:
enhanced.append({
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
}
)
self.state.chunk_findings = enhanced
return enhanced

View File

@@ -90,6 +90,7 @@ class LanceDBStorage:
# Raise it proactively so scans on large tables never hit OS error 24.
try:
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft < 4096:
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
@@ -110,7 +111,9 @@ class LanceDBStorage:
# If no table exists yet, defer creation until the first save so the
# dimension can be auto-detected from the embedder's actual output.
try:
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
self._table: lancedb.table.Table | None = self._db.open_table(
self._table_name
)
self._vector_dim: int = self._infer_dim_from_table(self._table)
# Best-effort: create the scope index if it doesn't exist yet.
self._ensure_scope_index()
@@ -171,7 +174,10 @@ class LanceDBStorage:
raise
_logger.debug(
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
op, attempt + 1, _MAX_RETRIES, delay,
op,
attempt + 1,
_MAX_RETRIES,
delay,
)
# Refresh table to pick up the latest version before retrying.
# The next getattr(self._table, op) will use the fresh table.
@@ -280,7 +286,9 @@ class LanceDBStorage:
"last_accessed": record.last_accessed.isoformat(),
"source": record.source or "",
"private": record.private,
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
"vector": record.embedding
if record.embedding
else [0.0] * self._vector_dim,
}
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
@@ -296,7 +304,9 @@ class LanceDBStorage:
id=str(row["id"]),
content=str(row["content"]),
scope=str(row["scope"]),
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
categories=json.loads(row["categories_str"])
if row.get("categories_str")
else [],
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
importance=float(row.get("importance", 0.5)),
created_at=_parse_dt(row.get("created_at")),
@@ -390,13 +400,17 @@ class LanceDBStorage:
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
out: list[tuple[MemoryRecord, float]] = []
for row in results:
record = self._row_to_record(row)
if categories and not any(c in record.categories for c in categories):
continue
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
if metadata_filter and not all(
record.metadata.get(k) == v for k, v in metadata_filter.items()
):
continue
distance = row.get("_distance", 0.0)
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
@@ -427,9 +441,13 @@ class LanceDBStorage:
to_delete: list[str] = []
for row in rows:
record = self._row_to_record(row)
if categories and not any(c in record.categories for c in categories):
if categories and not any(
c in record.categories for c in categories
):
continue
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
if metadata_filter and not all(
record.metadata.get(k) == v for k, v in metadata_filter.items()
):
continue
if older_than and record.created_at >= older_than:
continue
@@ -528,7 +546,7 @@ class LanceDBStorage:
for row in rows:
sc = str(row.get("scope", ""))
if child_prefix and sc.startswith(child_prefix):
rest = sc[len(child_prefix):]
rest = sc[len(child_prefix) :]
first_component = rest.split("/", 1)[0]
if first_component:
children.add(child_prefix + first_component)
@@ -539,7 +557,11 @@ class LanceDBStorage:
pass
created = row.get("created_at")
if created:
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
dt = (
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
if isinstance(created, str)
else created
)
if isinstance(dt, datetime):
if oldest is None or dt < oldest:
oldest = dt
@@ -562,7 +584,7 @@ class LanceDBStorage:
for row in rows:
sc = str(row.get("scope", ""))
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
rest = sc[len(prefix):]
rest = sc[len(prefix) :]
first_component = rest.split("/", 1)[0]
if first_component:
children.add(prefix + first_component)
@@ -600,7 +622,7 @@ class LanceDBStorage:
return
prefix = scope_prefix.rstrip("/")
if prefix:
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")
def optimize(self) -> None:
"""Compact the table synchronously and refresh the scope index.

View File

@@ -150,7 +150,11 @@ class Memory:
if isinstance(storage, str):
from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
self._storage = (
LanceDBStorage()
if storage == "lancedb"
else LanceDBStorage(path=storage)
)
else:
self._storage = storage

View File

@@ -100,7 +100,12 @@ class I18N(BaseModel):
def retrieve(
self,
kind: Literal[
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
"slices",
"errors",
"tools",
"reasoning",
"hierarchical_manager_agent",
"memory",
],
key: str,
) -> str:

View File

@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
A tuple of (type, Field) for use with create_model.
"""
type_ = _json_schema_to_pydantic_type(
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
json_schema,
root_schema,
name_=name.title(),
enrich_descriptions=enrich_descriptions,
)
is_required = name in required
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
if ref:
ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type(
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
ref_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
enum_values = json_schema.get("enum")
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
if all_of_schemas:
if len(all_of_schemas) == 1:
return _json_schema_to_pydantic_type(
all_of_schemas[0], root_schema, name_=name_,
all_of_schemas[0],
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
return _json_schema_to_pydantic_type(
merged, root_schema, name_=name_,
merged,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
items_schema = json_schema.get("items")
if items_schema:
item_type = _json_schema_to_pydantic_type(
items_schema, root_schema, name_=name_,
items_schema,
root_schema,
name_=name_,
enrich_descriptions=enrich_descriptions,
)
return list[item_type] # type: ignore[valid-type]
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
if json_schema_.get("title") is None:
json_schema_["title"] = name_ or "DynamicModel"
return create_model_from_schema(
json_schema_, root_schema=root_schema,
json_schema_,
root_schema=root_schema,
enrich_descriptions=enrich_descriptions,
)
return dict