mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: propagate contextvars across all thread and executor boundaries
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -137,7 +138,9 @@ class StagehandTool(BaseTool):
|
||||
- 'observe': For finding elements in a specific area
|
||||
"""
|
||||
args_schema: type[BaseModel] = StagehandToolSchema
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
|
||||
package_dependencies: list[str] = Field(
|
||||
default_factory=lambda: ["stagehand<=0.5.9"]
|
||||
)
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(
|
||||
@@ -620,9 +623,12 @@ class StagehandTool(BaseTool):
|
||||
# We're in an existing event loop, use it
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, self._async_run(instruction, url, command_type)
|
||||
ctx.run,
|
||||
asyncio.run,
|
||||
self._async_run(instruction, url, command_type),
|
||||
)
|
||||
result = future.result()
|
||||
else:
|
||||
@@ -706,11 +712,12 @@ class StagehandTool(BaseTool):
|
||||
if loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor() as executor
|
||||
):
|
||||
future = executor.submit(
|
||||
asyncio.run, self._async_close()
|
||||
ctx.run, asyncio.run, self._async_close()
|
||||
)
|
||||
future.result()
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Any
|
||||
import urllib.request
|
||||
@@ -66,7 +67,8 @@ def _track_install() -> None:
|
||||
def _track_install_async() -> None:
|
||||
"""Track installation in background thread to avoid blocking imports."""
|
||||
if not Telemetry._is_telemetry_disabled():
|
||||
thread = threading.Thread(target=_track_install, daemon=True)
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
@@ -147,8 +148,9 @@ def fetch_agent_card(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@@ -215,8 +217,9 @@ def _fetch_agent_card_cached(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
import concurrent.futures
|
||||
from contextlib import asynccontextmanager
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
@@ -229,8 +230,9 @@ def execute_a2a_delegation(
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
import json
|
||||
from types import MethodType
|
||||
@@ -276,9 +277,10 @@ def _fetch_agent_cards_concurrently(
|
||||
return agent_cards, failed_agents
|
||||
|
||||
max_workers = min(len(a2a_agents), 10)
|
||||
ctx = contextvars.copy_context()
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(_fetch_card_from_config, config): config
|
||||
executor.submit(ctx.run, _fetch_card_from_config, config): config
|
||||
for config in a2a_agents
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Sequence
|
||||
import contextvars
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -513,9 +514,13 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
self._execute_without_timeout, task_prompt=task_prompt, task=task
|
||||
ctx.run,
|
||||
self._execute_without_timeout,
|
||||
task_prompt=task_prompt,
|
||||
task=task,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextvars
|
||||
import json
|
||||
from pathlib import Path
|
||||
import platform
|
||||
@@ -80,7 +81,10 @@ def run_chat() -> None:
|
||||
|
||||
# Start loading indicator
|
||||
loading_complete = threading.Event()
|
||||
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
|
||||
ctx = contextvars.copy_context()
|
||||
loading_thread = threading.Thread(
|
||||
target=ctx.run, args=(show_loading, loading_complete)
|
||||
)
|
||||
loading_thread.start()
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from contextvars import ContextVar, Token
|
||||
from datetime import datetime
|
||||
import getpass
|
||||
@@ -509,7 +510,8 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
# Handle all input-related errors silently
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
ctx = contextvars.copy_context()
|
||||
input_thread = threading.Thread(target=ctx.run, args=(get_input,), daemon=True)
|
||||
input_thread.start()
|
||||
input_thread.join(timeout=timeout_seconds)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from collections.abc import (
|
||||
ValuesView,
|
||||
)
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
import copy
|
||||
import enum
|
||||
import inspect
|
||||
@@ -497,7 +498,9 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._list)
|
||||
|
||||
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
|
||||
def index(
|
||||
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
||||
) -> int: # type: ignore[override]
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
@@ -1811,8 +1814,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, _run_flow()).result()
|
||||
return pool.submit(ctx.run, asyncio.run, _run_flow()).result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_run_flow())
|
||||
|
||||
@@ -2236,8 +2240,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
# Run sync methods in thread pool for isolation
|
||||
# This allows Agent.kickoff() to work synchronously inside Flow methods
|
||||
import contextvars
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
|
||||
finally:
|
||||
@@ -2856,8 +2858,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Manual executor management to avoid shutdown(wait=True)
|
||||
# deadlock when the provider call outlives the timeout.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
ctx = contextvars.copy_context()
|
||||
future = executor.submit(
|
||||
provider.request_input, message, self, metadata
|
||||
ctx.run, provider.request_input, message, self, metadata
|
||||
)
|
||||
try:
|
||||
raw = future.result(timeout=timeout)
|
||||
|
||||
@@ -11,6 +11,7 @@ into a standalone MCPToolResolver. It handles three flavours of MCP reference:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from urllib.parse import urlparse
|
||||
@@ -22,10 +23,10 @@ from crewai.mcp.config import (
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -227,7 +228,9 @@ class MCPToolResolver:
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
|
||||
sanitized_specific_tool = (
|
||||
sanitize_tool_name(specific_tool) if specific_tool else None
|
||||
)
|
||||
|
||||
try:
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
@@ -353,9 +356,10 @@ class MCPToolResolver:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
ctx.run, asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
|
||||
@@ -11,6 +11,7 @@ Orchestrates the encoding side of memory in a single Flow with 5 steps:
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import math
|
||||
from typing import Any
|
||||
@@ -164,7 +165,11 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
def parallel_find_similar(self) -> None:
|
||||
"""Search storage for similar records, concurrently for all active items."""
|
||||
items = list(self.state.items)
|
||||
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
|
||||
active = [
|
||||
(i, item)
|
||||
for i, item in enumerate(items)
|
||||
if not item.dropped and item.embedding
|
||||
]
|
||||
|
||||
if not active:
|
||||
return
|
||||
@@ -185,8 +190,12 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
else:
|
||||
ctx = contextvars.copy_context()
|
||||
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
|
||||
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
|
||||
futures = [
|
||||
(i, item, pool.submit(ctx.run, _search_one, item))
|
||||
for i, item in active
|
||||
]
|
||||
for _, item, future in futures:
|
||||
raw = future.result()
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
@@ -229,6 +238,7 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
save_futures: dict[int, Future[MemoryAnalysis]] = {}
|
||||
consol_futures: dict[int, Future[ConsolidationPlan]] = {}
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
pool = ThreadPoolExecutor(max_workers=10)
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
@@ -250,24 +260,38 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
# Group B: consolidation only
|
||||
self._apply_defaults(item)
|
||||
consol_futures[i] = pool.submit(
|
||||
ctx.run,
|
||||
analyze_for_consolidation,
|
||||
item.content, list(item.similar_records), self._llm,
|
||||
item.content,
|
||||
list(item.similar_records),
|
||||
self._llm,
|
||||
)
|
||||
elif not fields_provided and not has_similar:
|
||||
# Group C: field resolution only
|
||||
save_futures[i] = pool.submit(
|
||||
ctx.run,
|
||||
analyze_for_save,
|
||||
item.content, existing_scopes, existing_categories, self._llm,
|
||||
item.content,
|
||||
existing_scopes,
|
||||
existing_categories,
|
||||
self._llm,
|
||||
)
|
||||
else:
|
||||
# Group D: both in parallel
|
||||
save_futures[i] = pool.submit(
|
||||
ctx.run,
|
||||
analyze_for_save,
|
||||
item.content, existing_scopes, existing_categories, self._llm,
|
||||
item.content,
|
||||
existing_scopes,
|
||||
existing_categories,
|
||||
self._llm,
|
||||
)
|
||||
consol_futures[i] = pool.submit(
|
||||
ctx.run,
|
||||
analyze_for_consolidation,
|
||||
item.content, list(item.similar_records), self._llm,
|
||||
item.content,
|
||||
list(item.similar_records),
|
||||
self._llm,
|
||||
)
|
||||
|
||||
# Collect field-resolution results
|
||||
@@ -339,7 +363,9 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
# similar_records overlap). Collect one action per record_id, first wins.
|
||||
# Also build a map from record_id to the original MemoryRecord for updates.
|
||||
dedup_deletes: set[str] = set() # record_ids to delete
|
||||
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
|
||||
dedup_updates: dict[
|
||||
str, tuple[int, str]
|
||||
] = {} # record_id -> (item_idx, new_content)
|
||||
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
|
||||
|
||||
for i, item in enumerate(items):
|
||||
@@ -350,13 +376,24 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
all_similar[r.id] = r
|
||||
for action in item.plan.actions:
|
||||
rid = action.record_id
|
||||
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
|
||||
if (
|
||||
action.action == "delete"
|
||||
and rid not in dedup_deletes
|
||||
and rid not in dedup_updates
|
||||
):
|
||||
dedup_deletes.add(rid)
|
||||
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
|
||||
elif (
|
||||
action.action == "update"
|
||||
and action.new_content
|
||||
and rid not in dedup_deletes
|
||||
and rid not in dedup_updates
|
||||
):
|
||||
dedup_updates[rid] = (i, action.new_content)
|
||||
|
||||
# --- Batch re-embed all update contents in ONE call ---
|
||||
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
|
||||
update_list = list(
|
||||
dedup_updates.items()
|
||||
) # [(record_id, (item_idx, new_content)), ...]
|
||||
update_embeddings: list[list[float]] = []
|
||||
if update_list:
|
||||
update_contents = [content for _, (_, content) in update_list]
|
||||
@@ -377,16 +414,21 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
if item.dropped or item.plan is None:
|
||||
continue
|
||||
if item.plan.insert_new:
|
||||
to_insert.append((i, MemoryRecord(
|
||||
content=item.content,
|
||||
scope=item.resolved_scope,
|
||||
categories=item.resolved_categories,
|
||||
metadata=item.resolved_metadata,
|
||||
importance=item.resolved_importance,
|
||||
embedding=item.embedding if item.embedding else None,
|
||||
source=item.resolved_source,
|
||||
private=item.resolved_private,
|
||||
)))
|
||||
to_insert.append(
|
||||
(
|
||||
i,
|
||||
MemoryRecord(
|
||||
content=item.content,
|
||||
scope=item.resolved_scope,
|
||||
categories=item.resolved_categories,
|
||||
metadata=item.resolved_metadata,
|
||||
importance=item.resolved_importance,
|
||||
embedding=item.embedding if item.embedding else None,
|
||||
source=item.resolved_source,
|
||||
private=item.resolved_private,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# All storage mutations under one lock so no other pipeline can
|
||||
# interleave and cause version conflicts. The lock is reentrant
|
||||
|
||||
@@ -11,6 +11,7 @@ Implements adaptive-depth retrieval with:
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
@@ -103,13 +104,12 @@ class RecallFlow(Flow[RecallState]):
|
||||
)
|
||||
# Post-filter by time cutoff
|
||||
if self.state.time_cutoff and raw:
|
||||
raw = [
|
||||
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
|
||||
]
|
||||
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
|
||||
# Privacy filter
|
||||
if not self.state.include_private and raw:
|
||||
raw = [
|
||||
(r, s) for r, s in raw
|
||||
(r, s)
|
||||
for r, s in raw
|
||||
if not r.private or r.source == self.state.source
|
||||
]
|
||||
return scope, raw
|
||||
@@ -130,15 +130,18 @@ class RecallFlow(Flow[RecallState]):
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
)
|
||||
findings.append({
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
})
|
||||
findings.append(
|
||||
{
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
}
|
||||
)
|
||||
else:
|
||||
ctx = contextvars.copy_context()
|
||||
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
|
||||
futures = {
|
||||
pool.submit(_search_one, emb, sc): (emb, sc)
|
||||
pool.submit(ctx.run, _search_one, emb, sc): (emb, sc)
|
||||
for emb, sc in tasks
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
@@ -147,16 +150,16 @@ class RecallFlow(Flow[RecallState]):
|
||||
top_composite, _ = compute_composite_score(
|
||||
results[0][0], results[0][1], self._config
|
||||
)
|
||||
findings.append({
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
})
|
||||
findings.append(
|
||||
{
|
||||
"scope": scope,
|
||||
"results": results,
|
||||
"top_score": top_composite,
|
||||
}
|
||||
)
|
||||
|
||||
self.state.chunk_findings = findings
|
||||
self.state.confidence = max(
|
||||
(f["top_score"] for f in findings), default=0.0
|
||||
)
|
||||
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
|
||||
return findings
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -210,12 +213,16 @@ class RecallFlow(Flow[RecallState]):
|
||||
# Parse time_filter into a datetime cutoff
|
||||
if analysis.time_filter:
|
||||
try:
|
||||
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
|
||||
self.state.time_cutoff = datetime.fromisoformat(
|
||||
analysis.time_filter
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Batch-embed all sub-queries in ONE call
|
||||
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||
queries = (
|
||||
analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||
)
|
||||
queries = queries[:3]
|
||||
embeddings = embed_texts(self._embedder, queries)
|
||||
pairs: list[tuple[str, list[float]]] = [
|
||||
@@ -296,17 +303,21 @@ class RecallFlow(Flow[RecallState]):
|
||||
response = self._llm.call([{"role": "user", "content": prompt}])
|
||||
if isinstance(response, str) and "missing" in response.lower():
|
||||
self.state.evidence_gaps.append(response[:200])
|
||||
enhanced.append({
|
||||
"scope": finding["scope"],
|
||||
"extraction": response,
|
||||
"results": finding["results"],
|
||||
})
|
||||
enhanced.append(
|
||||
{
|
||||
"scope": finding["scope"],
|
||||
"extraction": response,
|
||||
"results": finding["results"],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
enhanced.append({
|
||||
"scope": finding["scope"],
|
||||
"extraction": "",
|
||||
"results": finding["results"],
|
||||
})
|
||||
enhanced.append(
|
||||
{
|
||||
"scope": finding["scope"],
|
||||
"extraction": "",
|
||||
"results": finding["results"],
|
||||
}
|
||||
)
|
||||
self.state.chunk_findings = enhanced
|
||||
return enhanced
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -250,8 +251,10 @@ class LanceDBStorage:
|
||||
|
||||
def _compact_async(self) -> None:
|
||||
"""Fire-and-forget: compact the table in a daemon background thread."""
|
||||
ctx = contextvars.copy_context()
|
||||
threading.Thread(
|
||||
target=self._compact_safe,
|
||||
target=ctx.run,
|
||||
args=(self._compact_safe,),
|
||||
daemon=True,
|
||||
name="lancedb-compact",
|
||||
).start()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
@@ -229,8 +230,12 @@ class Memory(BaseModel):
|
||||
If the pool has been shut down (e.g. after ``close()``), the save
|
||||
runs synchronously as a fallback so late saves still succeed.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
try:
|
||||
future: Future[Any] = self._save_pool.submit(fn, *args, **kwargs)
|
||||
future: Future[Any] = self._save_pool.submit(
|
||||
ctx.run,
|
||||
lambda: fn(*args, **kwargs),
|
||||
)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
@@ -169,8 +170,9 @@ def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return pool.submit(ctx.run, asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
@@ -146,8 +147,9 @@ def _resolve_result(result: Any) -> Any:
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return pool.submit(ctx.run, asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ concurrently by the executor.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -84,9 +85,10 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
coro = self._run_async(**kwargs)
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
future = executor.submit(ctx.run, asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._run_async(**kwargs))
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
@@ -907,8 +908,9 @@ def summarize_messages(
|
||||
chunks=chunks, llm=llm, callbacks=callbacks, i18n=i18n
|
||||
)
|
||||
if is_inside_event_loop():
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
summarized_contents = pool.submit(asyncio.run, coro).result()
|
||||
summarized_contents = pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
else:
|
||||
summarized_contents = asyncio.run(coro)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from uuid import UUID
|
||||
@@ -46,8 +47,9 @@ def _run_sync(coro: Coroutine[None, None, T]) -> T:
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
future = executor.submit(ctx.run, asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
Reference in New Issue
Block a user