mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
chore: linter on previous files
This commit is contained in:
@@ -32,10 +32,7 @@ class CrewAgentExecutorMixin:
|
|||||||
)
|
)
|
||||||
if memory is None or not self.task or getattr(memory, "_read_only", False):
|
if memory is None or not self.task or getattr(memory, "_read_only", False):
|
||||||
return
|
return
|
||||||
if (
|
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
|
||||||
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
|
||||||
in output.text
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
raw = (
|
raw = (
|
||||||
@@ -48,6 +45,4 @@ class CrewAgentExecutorMixin:
|
|||||||
if extracted:
|
if extracted:
|
||||||
memory.remember_many(extracted, agent_role=self.agent.role)
|
memory.remember_many(extracted, agent_role=self.agent.role)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.agent._logger.log(
|
self.agent._logger.log("error", f"Failed to save to memory: {e}")
|
||||||
"error", f"Failed to save to memory: {e}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import contextvars
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import contextvars
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
@@ -895,7 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
ToolUsageStartedEvent,
|
ToolUsageStartedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
|
args_dict, parse_error = parse_tool_call_args(
|
||||||
|
func_args, func_name, call_id, original_tool
|
||||||
|
)
|
||||||
if parse_error is not None:
|
if parse_error is not None:
|
||||||
return parse_error
|
return parse_error
|
||||||
|
|
||||||
|
|||||||
@@ -182,15 +182,24 @@ def log_tasks_outputs() -> None:
|
|||||||
@crewai.command()
|
@crewai.command()
|
||||||
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
|
||||||
@click.option(
|
@click.option(
|
||||||
"-l", "--long", is_flag=True, hidden=True,
|
"-l",
|
||||||
|
"--long",
|
||||||
|
is_flag=True,
|
||||||
|
hidden=True,
|
||||||
help="[Deprecated: use --memory] Reset memory",
|
help="[Deprecated: use --memory] Reset memory",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-s", "--short", is_flag=True, hidden=True,
|
"-s",
|
||||||
|
"--short",
|
||||||
|
is_flag=True,
|
||||||
|
hidden=True,
|
||||||
help="[Deprecated: use --memory] Reset memory",
|
help="[Deprecated: use --memory] Reset memory",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-e", "--entities", is_flag=True, hidden=True,
|
"-e",
|
||||||
|
"--entities",
|
||||||
|
is_flag=True,
|
||||||
|
hidden=True,
|
||||||
help="[Deprecated: use --memory] Reset memory",
|
help="[Deprecated: use --memory] Reset memory",
|
||||||
)
|
)
|
||||||
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
|
||||||
@@ -218,7 +227,13 @@ def reset_memories(
|
|||||||
# Treat legacy flags as --memory with a deprecation warning
|
# Treat legacy flags as --memory with a deprecation warning
|
||||||
if long or short or entities:
|
if long or short or entities:
|
||||||
legacy_used = [
|
legacy_used = [
|
||||||
f for f, v in [("--long", long), ("--short", short), ("--entities", entities)] if v
|
f
|
||||||
|
for f, v in [
|
||||||
|
("--long", long),
|
||||||
|
("--short", short),
|
||||||
|
("--entities", entities),
|
||||||
|
]
|
||||||
|
if v
|
||||||
]
|
]
|
||||||
click.echo(
|
click.echo(
|
||||||
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
|
||||||
@@ -238,9 +253,7 @@ def reset_memories(
|
|||||||
"Please specify at least one memory type to reset using the appropriate flags."
|
"Please specify at least one memory type to reset using the appropriate flags."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
reset_memories_command(
|
reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
|
||||||
memory, knowledge, agent_knowledge, kickoff_outputs, all
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||||
|
|
||||||
|
|||||||
@@ -125,13 +125,19 @@ class MemoryTUI(App[None]):
|
|||||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||||
from crewai.memory.unified_memory import Memory
|
from crewai.memory.unified_memory import Memory
|
||||||
|
|
||||||
storage = LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
storage = (
|
||||||
|
LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
|
||||||
|
)
|
||||||
embedder = None
|
embedder = None
|
||||||
if embedder_config is not None:
|
if embedder_config is not None:
|
||||||
from crewai.rag.embeddings.factory import build_embedder
|
from crewai.rag.embeddings.factory import build_embedder
|
||||||
|
|
||||||
embedder = build_embedder(embedder_config)
|
embedder = build_embedder(embedder_config)
|
||||||
self._memory = Memory(storage=storage, embedder=embedder) if embedder else Memory(storage=storage)
|
self._memory = (
|
||||||
|
Memory(storage=storage, embedder=embedder)
|
||||||
|
if embedder
|
||||||
|
else Memory(storage=storage)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._init_error = str(e)
|
self._init_error = str(e)
|
||||||
|
|
||||||
@@ -200,11 +206,7 @@ class MemoryTUI(App[None]):
|
|||||||
if len(record.content) > 80
|
if len(record.content) > 80
|
||||||
else record.content
|
else record.content
|
||||||
)
|
)
|
||||||
label = (
|
label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
|
||||||
f"{date_str} "
|
|
||||||
f"[bold]{record.importance:.1f}[/] "
|
|
||||||
f"{preview}"
|
|
||||||
)
|
|
||||||
option_list.add_option(label)
|
option_list.add_option(label)
|
||||||
|
|
||||||
def _populate_recall_list(self) -> None:
|
def _populate_recall_list(self) -> None:
|
||||||
@@ -220,9 +222,7 @@ class MemoryTUI(App[None]):
|
|||||||
else m.record.content
|
else m.record.content
|
||||||
)
|
)
|
||||||
label = (
|
label = (
|
||||||
f"[bold]\\[{m.score:.2f}][/] "
|
f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
|
||||||
f"{preview} "
|
|
||||||
f"[dim]scope={m.record.scope}[/]"
|
|
||||||
)
|
)
|
||||||
option_list.add_option(label)
|
option_list.add_option(label)
|
||||||
|
|
||||||
@@ -251,8 +251,7 @@ class MemoryTUI(App[None]):
|
|||||||
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
|
||||||
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
|
||||||
lines.append(
|
lines.append(
|
||||||
f"[dim]Created:[/] "
|
f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
f"{record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
)
|
)
|
||||||
lines.append(
|
lines.append(
|
||||||
f"[dim]Last accessed:[/] "
|
f"[dim]Last accessed:[/] "
|
||||||
@@ -362,17 +361,11 @@ class MemoryTUI(App[None]):
|
|||||||
panel = self.query_one("#info-panel", Static)
|
panel = self.query_one("#info-panel", Static)
|
||||||
panel.loading = True
|
panel.loading = True
|
||||||
try:
|
try:
|
||||||
scope = (
|
scope = self._selected_scope if self._selected_scope != "/" else None
|
||||||
self._selected_scope
|
|
||||||
if self._selected_scope != "/"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
matches = await loop.run_in_executor(
|
matches = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: self._memory.recall(
|
lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
|
||||||
query, scope=scope, limit=10, depth="deep"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self._recall_matches = matches or []
|
self._recall_matches = matches or []
|
||||||
self._view_mode = "recall"
|
self._view_mode = "recall"
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ def reset_memories_command(
|
|||||||
continue
|
continue
|
||||||
if memory:
|
if memory:
|
||||||
_reset_flow_memory(flow)
|
_reset_flow_memory(flow)
|
||||||
click.echo(
|
click.echo(f"[Flow ({flow_name})] Memory has been reset.")
|
||||||
f"[Flow ({flow_name})] Memory has been reset."
|
|
||||||
)
|
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||||
|
|||||||
@@ -442,9 +442,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
|||||||
for search_path in search_paths:
|
for search_path in search_paths:
|
||||||
for root, dirs, files in os.walk(search_path):
|
for root, dirs, files in os.walk(search_path):
|
||||||
dirs[:] = [
|
dirs[:] = [
|
||||||
d
|
d for d in dirs if d not in _SKIP_DIRS and not d.startswith(".")
|
||||||
for d in dirs
|
|
||||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
|
||||||
]
|
]
|
||||||
if flow_path in files and "cli/templates" not in root:
|
if flow_path in files and "cli/templates" not in root:
|
||||||
file_os_path = os.path.join(root, flow_path)
|
file_os_path = os.path.join(root, flow_path)
|
||||||
@@ -464,9 +462,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
|||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
module_attr = getattr(module, attr_name)
|
module_attr = getattr(module, attr_name)
|
||||||
try:
|
try:
|
||||||
if flow_instance := get_flow_instance(
|
if flow_instance := get_flow_instance(module_attr):
|
||||||
module_attr
|
|
||||||
):
|
|
||||||
flow_instances.append(flow_instance)
|
flow_instances.append(flow_instance)
|
||||||
except Exception: # noqa: S112
|
except Exception: # noqa: S112
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import contextvars
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@@ -729,7 +729,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
|||||||
max_workers = min(8, len(runnable_tool_calls))
|
max_workers = min(8, len(runnable_tool_calls))
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
future_to_idx = {
|
future_to_idx = {
|
||||||
pool.submit(contextvars.copy_context().run, self._execute_single_native_tool_call, tool_call): idx
|
pool.submit(
|
||||||
|
contextvars.copy_context().run,
|
||||||
|
self._execute_single_native_tool_call,
|
||||||
|
tool_call,
|
||||||
|
): idx
|
||||||
for idx, tool_call in enumerate(runnable_tool_calls)
|
for idx, tool_call in enumerate(runnable_tool_calls)
|
||||||
}
|
}
|
||||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class ConsoleProvider:
|
|||||||
```python
|
```python
|
||||||
from crewai.flow.async_feedback import ConsoleProvider
|
from crewai.flow.async_feedback import ConsoleProvider
|
||||||
|
|
||||||
|
|
||||||
@human_feedback(
|
@human_feedback(
|
||||||
message="Review this:",
|
message="Review this:",
|
||||||
provider=ConsoleProvider(),
|
provider=ConsoleProvider(),
|
||||||
@@ -46,6 +47,7 @@ class ConsoleProvider:
|
|||||||
```python
|
```python
|
||||||
from crewai.flow import Flow, start
|
from crewai.flow import Flow, start
|
||||||
|
|
||||||
|
|
||||||
class MyFlow(Flow):
|
class MyFlow(Flow):
|
||||||
@start()
|
@start()
|
||||||
def gather_info(self):
|
def gather_info(self):
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ def human_feedback(
|
|||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
provider: HumanFeedbackProvider | None = None,
|
provider: HumanFeedbackProvider | None = None,
|
||||||
learn: bool = False,
|
learn: bool = False,
|
||||||
learn_source: str = "hitl"
|
learn_source: str = "hitl",
|
||||||
) -> Callable[[F], F]:
|
) -> Callable[[F], F]:
|
||||||
"""Decorator for Flow methods that require human feedback.
|
"""Decorator for Flow methods that require human feedback.
|
||||||
|
|
||||||
@@ -328,9 +328,7 @@ def human_feedback(
|
|||||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||||
try:
|
try:
|
||||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||||
matches = flow_instance.memory.recall(
|
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||||
query, source=learn_source
|
|
||||||
)
|
|
||||||
if not matches:
|
if not matches:
|
||||||
return method_output
|
return method_output
|
||||||
|
|
||||||
@@ -341,7 +339,10 @@ def human_feedback(
|
|||||||
lessons=lessons,
|
lessons=lessons,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": _get_hitl_prompt("hitl_pre_review_system")},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": _get_hitl_prompt("hitl_pre_review_system"),
|
||||||
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
if getattr(llm_inst, "supports_function_calling", lambda: False)():
|
||||||
@@ -366,7 +367,10 @@ def human_feedback(
|
|||||||
feedback=raw_feedback,
|
feedback=raw_feedback,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": _get_hitl_prompt("hitl_distill_system")},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": _get_hitl_prompt("hitl_distill_system"),
|
||||||
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -487,7 +491,11 @@ def human_feedback(
|
|||||||
result = _process_feedback(self, method_output, raw_feedback)
|
result = _process_feedback(self, method_output, raw_feedback)
|
||||||
|
|
||||||
# Distill: extract lessons from output + feedback, store in memory
|
# Distill: extract lessons from output + feedback, store in memory
|
||||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
if (
|
||||||
|
learn
|
||||||
|
and getattr(self, "memory", None) is not None
|
||||||
|
and raw_feedback.strip()
|
||||||
|
):
|
||||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -507,7 +515,11 @@ def human_feedback(
|
|||||||
result = _process_feedback(self, method_output, raw_feedback)
|
result = _process_feedback(self, method_output, raw_feedback)
|
||||||
|
|
||||||
# Distill: extract lessons from output + feedback, store in memory
|
# Distill: extract lessons from output + feedback, store in memory
|
||||||
if learn and getattr(self, "memory", None) is not None and raw_feedback.strip():
|
if (
|
||||||
|
learn
|
||||||
|
and getattr(self, "memory", None) is not None
|
||||||
|
and raw_feedback.strip()
|
||||||
|
):
|
||||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -534,7 +546,7 @@ def human_feedback(
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
learn=learn,
|
learn=learn,
|
||||||
learn_source=learn_source
|
learn_source=learn_source,
|
||||||
)
|
)
|
||||||
wrapper.__is_flow_method__ = True
|
wrapper.__is_flow_method__ = True
|
||||||
|
|
||||||
|
|||||||
@@ -308,7 +308,9 @@ def analyze_for_save(
|
|||||||
return MemoryAnalysis.model_validate(response)
|
return MemoryAnalysis.model_validate(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.warning(
|
_logger.warning(
|
||||||
"Memory save analysis failed, using defaults: %s", e, exc_info=False,
|
"Memory save analysis failed, using defaults: %s",
|
||||||
|
e,
|
||||||
|
exc_info=False,
|
||||||
)
|
)
|
||||||
return _SAVE_DEFAULTS
|
return _SAVE_DEFAULTS
|
||||||
|
|
||||||
@@ -366,6 +368,8 @@ def analyze_for_consolidation(
|
|||||||
return ConsolidationPlan.model_validate(response)
|
return ConsolidationPlan.model_validate(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.warning(
|
_logger.warning(
|
||||||
"Consolidation analysis failed, defaulting to insert: %s", e, exc_info=False,
|
"Consolidation analysis failed, defaulting to insert: %s",
|
||||||
|
e,
|
||||||
|
exc_info=False,
|
||||||
)
|
)
|
||||||
return _CONSOLIDATION_DEFAULT
|
return _CONSOLIDATION_DEFAULT
|
||||||
|
|||||||
@@ -164,7 +164,11 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
def parallel_find_similar(self) -> None:
|
def parallel_find_similar(self) -> None:
|
||||||
"""Search storage for similar records, concurrently for all active items."""
|
"""Search storage for similar records, concurrently for all active items."""
|
||||||
items = list(self.state.items)
|
items = list(self.state.items)
|
||||||
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
|
active = [
|
||||||
|
(i, item)
|
||||||
|
for i, item in enumerate(items)
|
||||||
|
if not item.dropped and item.embedding
|
||||||
|
]
|
||||||
|
|
||||||
if not active:
|
if not active:
|
||||||
return
|
return
|
||||||
@@ -186,7 +190,9 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||||
else:
|
else:
|
||||||
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
|
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
|
||||||
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
|
futures = [
|
||||||
|
(i, item, pool.submit(_search_one, item)) for i, item in active
|
||||||
|
]
|
||||||
for _, item, future in futures:
|
for _, item, future in futures:
|
||||||
raw = future.result()
|
raw = future.result()
|
||||||
item.similar_records = [r for r, _ in raw]
|
item.similar_records = [r for r, _ in raw]
|
||||||
@@ -251,23 +257,33 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
self._apply_defaults(item)
|
self._apply_defaults(item)
|
||||||
consol_futures[i] = pool.submit(
|
consol_futures[i] = pool.submit(
|
||||||
analyze_for_consolidation,
|
analyze_for_consolidation,
|
||||||
item.content, list(item.similar_records), self._llm,
|
item.content,
|
||||||
|
list(item.similar_records),
|
||||||
|
self._llm,
|
||||||
)
|
)
|
||||||
elif not fields_provided and not has_similar:
|
elif not fields_provided and not has_similar:
|
||||||
# Group C: field resolution only
|
# Group C: field resolution only
|
||||||
save_futures[i] = pool.submit(
|
save_futures[i] = pool.submit(
|
||||||
analyze_for_save,
|
analyze_for_save,
|
||||||
item.content, existing_scopes, existing_categories, self._llm,
|
item.content,
|
||||||
|
existing_scopes,
|
||||||
|
existing_categories,
|
||||||
|
self._llm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Group D: both in parallel
|
# Group D: both in parallel
|
||||||
save_futures[i] = pool.submit(
|
save_futures[i] = pool.submit(
|
||||||
analyze_for_save,
|
analyze_for_save,
|
||||||
item.content, existing_scopes, existing_categories, self._llm,
|
item.content,
|
||||||
|
existing_scopes,
|
||||||
|
existing_categories,
|
||||||
|
self._llm,
|
||||||
)
|
)
|
||||||
consol_futures[i] = pool.submit(
|
consol_futures[i] = pool.submit(
|
||||||
analyze_for_consolidation,
|
analyze_for_consolidation,
|
||||||
item.content, list(item.similar_records), self._llm,
|
item.content,
|
||||||
|
list(item.similar_records),
|
||||||
|
self._llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect field-resolution results
|
# Collect field-resolution results
|
||||||
@@ -339,7 +355,9 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
# similar_records overlap). Collect one action per record_id, first wins.
|
# similar_records overlap). Collect one action per record_id, first wins.
|
||||||
# Also build a map from record_id to the original MemoryRecord for updates.
|
# Also build a map from record_id to the original MemoryRecord for updates.
|
||||||
dedup_deletes: set[str] = set() # record_ids to delete
|
dedup_deletes: set[str] = set() # record_ids to delete
|
||||||
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
|
dedup_updates: dict[
|
||||||
|
str, tuple[int, str]
|
||||||
|
] = {} # record_id -> (item_idx, new_content)
|
||||||
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
|
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
|
||||||
|
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
@@ -350,13 +368,24 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
all_similar[r.id] = r
|
all_similar[r.id] = r
|
||||||
for action in item.plan.actions:
|
for action in item.plan.actions:
|
||||||
rid = action.record_id
|
rid = action.record_id
|
||||||
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
|
if (
|
||||||
|
action.action == "delete"
|
||||||
|
and rid not in dedup_deletes
|
||||||
|
and rid not in dedup_updates
|
||||||
|
):
|
||||||
dedup_deletes.add(rid)
|
dedup_deletes.add(rid)
|
||||||
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
|
elif (
|
||||||
|
action.action == "update"
|
||||||
|
and action.new_content
|
||||||
|
and rid not in dedup_deletes
|
||||||
|
and rid not in dedup_updates
|
||||||
|
):
|
||||||
dedup_updates[rid] = (i, action.new_content)
|
dedup_updates[rid] = (i, action.new_content)
|
||||||
|
|
||||||
# --- Batch re-embed all update contents in ONE call ---
|
# --- Batch re-embed all update contents in ONE call ---
|
||||||
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
|
update_list = list(
|
||||||
|
dedup_updates.items()
|
||||||
|
) # [(record_id, (item_idx, new_content)), ...]
|
||||||
update_embeddings: list[list[float]] = []
|
update_embeddings: list[list[float]] = []
|
||||||
if update_list:
|
if update_list:
|
||||||
update_contents = [content for _, (_, content) in update_list]
|
update_contents = [content for _, (_, content) in update_list]
|
||||||
@@ -377,16 +406,21 @@ class EncodingFlow(Flow[EncodingState]):
|
|||||||
if item.dropped or item.plan is None:
|
if item.dropped or item.plan is None:
|
||||||
continue
|
continue
|
||||||
if item.plan.insert_new:
|
if item.plan.insert_new:
|
||||||
to_insert.append((i, MemoryRecord(
|
to_insert.append(
|
||||||
content=item.content,
|
(
|
||||||
scope=item.resolved_scope,
|
i,
|
||||||
categories=item.resolved_categories,
|
MemoryRecord(
|
||||||
metadata=item.resolved_metadata,
|
content=item.content,
|
||||||
importance=item.resolved_importance,
|
scope=item.resolved_scope,
|
||||||
embedding=item.embedding if item.embedding else None,
|
categories=item.resolved_categories,
|
||||||
source=item.resolved_source,
|
metadata=item.resolved_metadata,
|
||||||
private=item.resolved_private,
|
importance=item.resolved_importance,
|
||||||
)))
|
embedding=item.embedding if item.embedding else None,
|
||||||
|
source=item.resolved_source,
|
||||||
|
private=item.resolved_private,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# All storage mutations under one lock so no other pipeline can
|
# All storage mutations under one lock so no other pipeline can
|
||||||
# interleave and cause version conflicts. The lock is reentrant
|
# interleave and cause version conflicts. The lock is reentrant
|
||||||
|
|||||||
@@ -249,9 +249,17 @@ class MemorySlice:
|
|||||||
total_records += inf.record_count
|
total_records += inf.record_count
|
||||||
all_categories.update(inf.categories)
|
all_categories.update(inf.categories)
|
||||||
if inf.oldest_record:
|
if inf.oldest_record:
|
||||||
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
|
oldest = (
|
||||||
|
inf.oldest_record
|
||||||
|
if oldest is None
|
||||||
|
else min(oldest, inf.oldest_record)
|
||||||
|
)
|
||||||
if inf.newest_record:
|
if inf.newest_record:
|
||||||
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
|
newest = (
|
||||||
|
inf.newest_record
|
||||||
|
if newest is None
|
||||||
|
else max(newest, inf.newest_record)
|
||||||
|
)
|
||||||
children.extend(inf.child_scopes)
|
children.extend(inf.child_scopes)
|
||||||
return ScopeInfo(
|
return ScopeInfo(
|
||||||
path=path,
|
path=path,
|
||||||
|
|||||||
@@ -103,13 +103,12 @@ class RecallFlow(Flow[RecallState]):
|
|||||||
)
|
)
|
||||||
# Post-filter by time cutoff
|
# Post-filter by time cutoff
|
||||||
if self.state.time_cutoff and raw:
|
if self.state.time_cutoff and raw:
|
||||||
raw = [
|
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
|
||||||
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
|
|
||||||
]
|
|
||||||
# Privacy filter
|
# Privacy filter
|
||||||
if not self.state.include_private and raw:
|
if not self.state.include_private and raw:
|
||||||
raw = [
|
raw = [
|
||||||
(r, s) for r, s in raw
|
(r, s)
|
||||||
|
for r, s in raw
|
||||||
if not r.private or r.source == self.state.source
|
if not r.private or r.source == self.state.source
|
||||||
]
|
]
|
||||||
return scope, raw
|
return scope, raw
|
||||||
@@ -130,16 +129,17 @@ class RecallFlow(Flow[RecallState]):
|
|||||||
top_composite, _ = compute_composite_score(
|
top_composite, _ = compute_composite_score(
|
||||||
results[0][0], results[0][1], self._config
|
results[0][0], results[0][1], self._config
|
||||||
)
|
)
|
||||||
findings.append({
|
findings.append(
|
||||||
"scope": scope,
|
{
|
||||||
"results": results,
|
"scope": scope,
|
||||||
"top_score": top_composite,
|
"results": results,
|
||||||
})
|
"top_score": top_composite,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
|
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
|
||||||
futures = {
|
futures = {
|
||||||
pool.submit(_search_one, emb, sc): (emb, sc)
|
pool.submit(_search_one, emb, sc): (emb, sc) for emb, sc in tasks
|
||||||
for emb, sc in tasks
|
|
||||||
}
|
}
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
scope, results = future.result()
|
scope, results = future.result()
|
||||||
@@ -147,16 +147,16 @@ class RecallFlow(Flow[RecallState]):
|
|||||||
top_composite, _ = compute_composite_score(
|
top_composite, _ = compute_composite_score(
|
||||||
results[0][0], results[0][1], self._config
|
results[0][0], results[0][1], self._config
|
||||||
)
|
)
|
||||||
findings.append({
|
findings.append(
|
||||||
"scope": scope,
|
{
|
||||||
"results": results,
|
"scope": scope,
|
||||||
"top_score": top_composite,
|
"results": results,
|
||||||
})
|
"top_score": top_composite,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.state.chunk_findings = findings
|
self.state.chunk_findings = findings
|
||||||
self.state.confidence = max(
|
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
|
||||||
(f["top_score"] for f in findings), default=0.0
|
|
||||||
)
|
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -210,12 +210,16 @@ class RecallFlow(Flow[RecallState]):
|
|||||||
# Parse time_filter into a datetime cutoff
|
# Parse time_filter into a datetime cutoff
|
||||||
if analysis.time_filter:
|
if analysis.time_filter:
|
||||||
try:
|
try:
|
||||||
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
|
self.state.time_cutoff = datetime.fromisoformat(
|
||||||
|
analysis.time_filter
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Batch-embed all sub-queries in ONE call
|
# Batch-embed all sub-queries in ONE call
|
||||||
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
queries = (
|
||||||
|
analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||||
|
)
|
||||||
queries = queries[:3]
|
queries = queries[:3]
|
||||||
embeddings = embed_texts(self._embedder, queries)
|
embeddings = embed_texts(self._embedder, queries)
|
||||||
pairs: list[tuple[str, list[float]]] = [
|
pairs: list[tuple[str, list[float]]] = [
|
||||||
@@ -296,17 +300,21 @@ class RecallFlow(Flow[RecallState]):
|
|||||||
response = self._llm.call([{"role": "user", "content": prompt}])
|
response = self._llm.call([{"role": "user", "content": prompt}])
|
||||||
if isinstance(response, str) and "missing" in response.lower():
|
if isinstance(response, str) and "missing" in response.lower():
|
||||||
self.state.evidence_gaps.append(response[:200])
|
self.state.evidence_gaps.append(response[:200])
|
||||||
enhanced.append({
|
enhanced.append(
|
||||||
"scope": finding["scope"],
|
{
|
||||||
"extraction": response,
|
"scope": finding["scope"],
|
||||||
"results": finding["results"],
|
"extraction": response,
|
||||||
})
|
"results": finding["results"],
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
enhanced.append({
|
enhanced.append(
|
||||||
"scope": finding["scope"],
|
{
|
||||||
"extraction": "",
|
"scope": finding["scope"],
|
||||||
"results": finding["results"],
|
"extraction": "",
|
||||||
})
|
"results": finding["results"],
|
||||||
|
}
|
||||||
|
)
|
||||||
self.state.chunk_findings = enhanced
|
self.state.chunk_findings = enhanced
|
||||||
return enhanced
|
return enhanced
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class LanceDBStorage:
|
|||||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||||
try:
|
try:
|
||||||
import resource
|
import resource
|
||||||
|
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
if soft < 4096:
|
if soft < 4096:
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
||||||
@@ -110,7 +111,9 @@ class LanceDBStorage:
|
|||||||
# If no table exists yet, defer creation until the first save so the
|
# If no table exists yet, defer creation until the first save so the
|
||||||
# dimension can be auto-detected from the embedder's actual output.
|
# dimension can be auto-detected from the embedder's actual output.
|
||||||
try:
|
try:
|
||||||
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
|
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||||
|
self._table_name
|
||||||
|
)
|
||||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||||
# Best-effort: create the scope index if it doesn't exist yet.
|
# Best-effort: create the scope index if it doesn't exist yet.
|
||||||
self._ensure_scope_index()
|
self._ensure_scope_index()
|
||||||
@@ -171,7 +174,10 @@ class LanceDBStorage:
|
|||||||
raise
|
raise
|
||||||
_logger.debug(
|
_logger.debug(
|
||||||
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
||||||
op, attempt + 1, _MAX_RETRIES, delay,
|
op,
|
||||||
|
attempt + 1,
|
||||||
|
_MAX_RETRIES,
|
||||||
|
delay,
|
||||||
)
|
)
|
||||||
# Refresh table to pick up the latest version before retrying.
|
# Refresh table to pick up the latest version before retrying.
|
||||||
# The next getattr(self._table, op) will use the fresh table.
|
# The next getattr(self._table, op) will use the fresh table.
|
||||||
@@ -280,7 +286,9 @@ class LanceDBStorage:
|
|||||||
"last_accessed": record.last_accessed.isoformat(),
|
"last_accessed": record.last_accessed.isoformat(),
|
||||||
"source": record.source or "",
|
"source": record.source or "",
|
||||||
"private": record.private,
|
"private": record.private,
|
||||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
"vector": record.embedding
|
||||||
|
if record.embedding
|
||||||
|
else [0.0] * self._vector_dim,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||||
@@ -296,7 +304,9 @@ class LanceDBStorage:
|
|||||||
id=str(row["id"]),
|
id=str(row["id"]),
|
||||||
content=str(row["content"]),
|
content=str(row["content"]),
|
||||||
scope=str(row["scope"]),
|
scope=str(row["scope"]),
|
||||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
categories=json.loads(row["categories_str"])
|
||||||
|
if row.get("categories_str")
|
||||||
|
else [],
|
||||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||||
importance=float(row.get("importance", 0.5)),
|
importance=float(row.get("importance", 0.5)),
|
||||||
created_at=_parse_dt(row.get("created_at")),
|
created_at=_parse_dt(row.get("created_at")),
|
||||||
@@ -390,13 +400,17 @@ class LanceDBStorage:
|
|||||||
prefix = scope_prefix.rstrip("/")
|
prefix = scope_prefix.rstrip("/")
|
||||||
like_val = prefix + "%"
|
like_val = prefix + "%"
|
||||||
query = query.where(f"scope LIKE '{like_val}'")
|
query = query.where(f"scope LIKE '{like_val}'")
|
||||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
results = query.limit(
|
||||||
|
limit * 3 if (categories or metadata_filter) else limit
|
||||||
|
).to_list()
|
||||||
out: list[tuple[MemoryRecord, float]] = []
|
out: list[tuple[MemoryRecord, float]] = []
|
||||||
for row in results:
|
for row in results:
|
||||||
record = self._row_to_record(row)
|
record = self._row_to_record(row)
|
||||||
if categories and not any(c in record.categories for c in categories):
|
if categories and not any(c in record.categories for c in categories):
|
||||||
continue
|
continue
|
||||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
if metadata_filter and not all(
|
||||||
|
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
distance = row.get("_distance", 0.0)
|
distance = row.get("_distance", 0.0)
|
||||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||||
@@ -427,9 +441,13 @@ class LanceDBStorage:
|
|||||||
to_delete: list[str] = []
|
to_delete: list[str] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
record = self._row_to_record(row)
|
record = self._row_to_record(row)
|
||||||
if categories and not any(c in record.categories for c in categories):
|
if categories and not any(
|
||||||
|
c in record.categories for c in categories
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
if metadata_filter and not all(
|
||||||
|
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if older_than and record.created_at >= older_than:
|
if older_than and record.created_at >= older_than:
|
||||||
continue
|
continue
|
||||||
@@ -528,7 +546,7 @@ class LanceDBStorage:
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
sc = str(row.get("scope", ""))
|
sc = str(row.get("scope", ""))
|
||||||
if child_prefix and sc.startswith(child_prefix):
|
if child_prefix and sc.startswith(child_prefix):
|
||||||
rest = sc[len(child_prefix):]
|
rest = sc[len(child_prefix) :]
|
||||||
first_component = rest.split("/", 1)[0]
|
first_component = rest.split("/", 1)[0]
|
||||||
if first_component:
|
if first_component:
|
||||||
children.add(child_prefix + first_component)
|
children.add(child_prefix + first_component)
|
||||||
@@ -539,7 +557,11 @@ class LanceDBStorage:
|
|||||||
pass
|
pass
|
||||||
created = row.get("created_at")
|
created = row.get("created_at")
|
||||||
if created:
|
if created:
|
||||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
dt = (
|
||||||
|
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
|
||||||
|
if isinstance(created, str)
|
||||||
|
else created
|
||||||
|
)
|
||||||
if isinstance(dt, datetime):
|
if isinstance(dt, datetime):
|
||||||
if oldest is None or dt < oldest:
|
if oldest is None or dt < oldest:
|
||||||
oldest = dt
|
oldest = dt
|
||||||
@@ -562,7 +584,7 @@ class LanceDBStorage:
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
sc = str(row.get("scope", ""))
|
sc = str(row.get("scope", ""))
|
||||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||||
rest = sc[len(prefix):]
|
rest = sc[len(prefix) :]
|
||||||
first_component = rest.split("/", 1)[0]
|
first_component = rest.split("/", 1)[0]
|
||||||
if first_component:
|
if first_component:
|
||||||
children.add(prefix + first_component)
|
children.add(prefix + first_component)
|
||||||
@@ -600,7 +622,7 @@ class LanceDBStorage:
|
|||||||
return
|
return
|
||||||
prefix = scope_prefix.rstrip("/")
|
prefix = scope_prefix.rstrip("/")
|
||||||
if prefix:
|
if prefix:
|
||||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")
|
||||||
|
|
||||||
def optimize(self) -> None:
|
def optimize(self) -> None:
|
||||||
"""Compact the table synchronously and refresh the scope index.
|
"""Compact the table synchronously and refresh the scope index.
|
||||||
|
|||||||
@@ -150,7 +150,11 @@ class Memory:
|
|||||||
if isinstance(storage, str):
|
if isinstance(storage, str):
|
||||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||||
|
|
||||||
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
|
self._storage = (
|
||||||
|
LanceDBStorage()
|
||||||
|
if storage == "lancedb"
|
||||||
|
else LanceDBStorage(path=storage)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._storage = storage
|
self._storage = storage
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,12 @@ class I18N(BaseModel):
|
|||||||
def retrieve(
|
def retrieve(
|
||||||
self,
|
self,
|
||||||
kind: Literal[
|
kind: Literal[
|
||||||
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent", "memory"
|
"slices",
|
||||||
|
"errors",
|
||||||
|
"tools",
|
||||||
|
"reasoning",
|
||||||
|
"hierarchical_manager_agent",
|
||||||
|
"memory",
|
||||||
],
|
],
|
||||||
key: str,
|
key: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|||||||
@@ -657,7 +657,10 @@ def _json_schema_to_pydantic_field(
|
|||||||
A tuple of (type, Field) for use with create_model.
|
A tuple of (type, Field) for use with create_model.
|
||||||
"""
|
"""
|
||||||
type_ = _json_schema_to_pydantic_type(
|
type_ = _json_schema_to_pydantic_type(
|
||||||
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
|
json_schema,
|
||||||
|
root_schema,
|
||||||
|
name_=name.title(),
|
||||||
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
is_required = name in required
|
is_required = name in required
|
||||||
|
|
||||||
@@ -806,7 +809,10 @@ def _json_schema_to_pydantic_type(
|
|||||||
if ref:
|
if ref:
|
||||||
ref_schema = _resolve_ref(ref, root_schema)
|
ref_schema = _resolve_ref(ref, root_schema)
|
||||||
return _json_schema_to_pydantic_type(
|
return _json_schema_to_pydantic_type(
|
||||||
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
|
ref_schema,
|
||||||
|
root_schema,
|
||||||
|
name_=name_,
|
||||||
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
enum_values = json_schema.get("enum")
|
enum_values = json_schema.get("enum")
|
||||||
@@ -835,12 +841,16 @@ def _json_schema_to_pydantic_type(
|
|||||||
if all_of_schemas:
|
if all_of_schemas:
|
||||||
if len(all_of_schemas) == 1:
|
if len(all_of_schemas) == 1:
|
||||||
return _json_schema_to_pydantic_type(
|
return _json_schema_to_pydantic_type(
|
||||||
all_of_schemas[0], root_schema, name_=name_,
|
all_of_schemas[0],
|
||||||
|
root_schema,
|
||||||
|
name_=name_,
|
||||||
enrich_descriptions=enrich_descriptions,
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||||
return _json_schema_to_pydantic_type(
|
return _json_schema_to_pydantic_type(
|
||||||
merged, root_schema, name_=name_,
|
merged,
|
||||||
|
root_schema,
|
||||||
|
name_=name_,
|
||||||
enrich_descriptions=enrich_descriptions,
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -858,7 +868,9 @@ def _json_schema_to_pydantic_type(
|
|||||||
items_schema = json_schema.get("items")
|
items_schema = json_schema.get("items")
|
||||||
if items_schema:
|
if items_schema:
|
||||||
item_type = _json_schema_to_pydantic_type(
|
item_type = _json_schema_to_pydantic_type(
|
||||||
items_schema, root_schema, name_=name_,
|
items_schema,
|
||||||
|
root_schema,
|
||||||
|
name_=name_,
|
||||||
enrich_descriptions=enrich_descriptions,
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
return list[item_type] # type: ignore[valid-type]
|
return list[item_type] # type: ignore[valid-type]
|
||||||
@@ -870,7 +882,8 @@ def _json_schema_to_pydantic_type(
|
|||||||
if json_schema_.get("title") is None:
|
if json_schema_.get("title") is None:
|
||||||
json_schema_["title"] = name_ or "DynamicModel"
|
json_schema_["title"] = name_ or "DynamicModel"
|
||||||
return create_model_from_schema(
|
return create_model_from_schema(
|
||||||
json_schema_, root_schema=root_schema,
|
json_schema_,
|
||||||
|
root_schema=root_schema,
|
||||||
enrich_descriptions=enrich_descriptions,
|
enrich_descriptions=enrich_descriptions,
|
||||||
)
|
)
|
||||||
return dict
|
return dict
|
||||||
|
|||||||
Reference in New Issue
Block a user