fix: propagate contextvars across all thread and executor boundaries

This commit is contained in:
Greyson Lalonde
2026-03-12 18:49:24 -04:00
parent d8e38f2f0b
commit 61c83df3b9
19 changed files with 181 additions and 76 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import contextvars
import json
import os
import re
@@ -137,7 +138,9 @@ class StagehandTool(BaseTool):
- 'observe': For finding elements in a specific area
"""
args_schema: type[BaseModel] = StagehandToolSchema
package_dependencies: list[str] = Field(default_factory=lambda: ["stagehand<=0.5.9"])
package_dependencies: list[str] = Field(
default_factory=lambda: ["stagehand<=0.5.9"]
)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
@@ -620,9 +623,12 @@ class StagehandTool(BaseTool):
# We're in an existing event loop, use it
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run, self._async_run(instruction, url, command_type)
ctx.run,
asyncio.run,
self._async_run(instruction, url, command_type),
)
result = future.result()
else:
@@ -706,11 +712,12 @@ class StagehandTool(BaseTool):
if loop.is_running():
import concurrent.futures
ctx = contextvars.copy_context()
with (
concurrent.futures.ThreadPoolExecutor() as executor
):
future = executor.submit(
asyncio.run, self._async_close()
ctx.run, asyncio.run, self._async_close()
)
future.result()
else:

View File

@@ -1,3 +1,4 @@
import contextvars
import threading
from typing import Any
import urllib.request
@@ -66,7 +67,8 @@ def _track_install() -> None:
def _track_install_async() -> None:
"""Track installation in background thread to avoid blocking imports."""
if not Telemetry._is_telemetry_disabled():
thread = threading.Thread(target=_track_install, daemon=True)
ctx = contextvars.copy_context()
thread = threading.Thread(target=ctx.run, args=(_track_install,), daemon=True)
thread.start()

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
import concurrent.futures
import contextvars
from functools import lru_cache
import ssl
import time
@@ -147,8 +148,9 @@ def fetch_agent_card(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, coro).result()
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)
@@ -215,8 +217,9 @@ def _fetch_agent_card_cached(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, coro).result()
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)

View File

@@ -7,6 +7,7 @@ import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
import concurrent.futures
from contextlib import asynccontextmanager
import contextvars
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
@@ -229,8 +230,9 @@ def execute_a2a_delegation(
has_running_loop = False
if has_running_loop:
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, coro).result()
return pool.submit(ctx.run, asyncio.run, coro).result()
return asyncio.run(coro)

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Mapping
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from functools import wraps
import json
from types import MethodType
@@ -276,9 +277,10 @@ def _fetch_agent_cards_concurrently(
return agent_cards, failed_agents
max_workers = min(len(a2a_agents), 10)
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_fetch_card_from_config, config): config
executor.submit(ctx.run, _fetch_card_from_config, config): config
for config in a2a_agents
}
for future in as_completed(futures):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Sequence
import contextvars
import shutil
import subprocess
import time
@@ -513,9 +514,13 @@ class Agent(BaseAgent):
"""
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
self._execute_without_timeout, task_prompt=task_prompt, task=task
ctx.run,
self._execute_without_timeout,
task_prompt=task_prompt,
task=task,
)
try:

View File

@@ -1,3 +1,4 @@
import contextvars
import json
from pathlib import Path
import platform
@@ -80,7 +81,10 @@ def run_chat() -> None:
# Start loading indicator
loading_complete = threading.Event()
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
ctx = contextvars.copy_context()
loading_thread = threading.Thread(
target=ctx.run, args=(show_loading, loading_complete)
)
loading_thread.start()
try:

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable
import contextvars
from contextvars import ContextVar, Token
from datetime import datetime
import getpass
@@ -509,7 +510,8 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
# Handle all input-related errors silently
result[0] = False
input_thread = threading.Thread(target=get_input, daemon=True)
ctx = contextvars.copy_context()
input_thread = threading.Thread(target=ctx.run, args=(get_input,), daemon=True)
input_thread.start()
input_thread.join(timeout=timeout_seconds)

View File

@@ -17,6 +17,7 @@ from collections.abc import (
ValuesView,
)
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
import copy
import enum
import inspect
@@ -497,7 +498,9 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._list)
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
def index(
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
) -> int: # type: ignore[override]
if stop is None:
return self._list.index(value, start)
return self._list.index(value, start, stop)
@@ -1811,8 +1814,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
try:
asyncio.get_running_loop()
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, _run_flow()).result()
return pool.submit(ctx.run, asyncio.run, _run_flow()).result()
except RuntimeError:
return asyncio.run(_run_flow())
@@ -2236,8 +2240,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
finally:
@@ -2856,8 +2858,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Manual executor management to avoid shutdown(wait=True)
# deadlock when the provider call outlives the timeout.
executor = ThreadPoolExecutor(max_workers=1)
ctx = contextvars.copy_context()
future = executor.submit(
provider.request_input, message, self, metadata
ctx.run, provider.request_input, message, self, metadata
)
try:
raw = future.result(timeout=timeout)

View File

@@ -11,6 +11,7 @@ into a standalone MCPToolResolver. It handles three flavours of MCP reference:
from __future__ import annotations
import asyncio
import contextvars
import time
from typing import TYPE_CHECKING, Any, Final, cast
from urllib.parse import urlparse
@@ -22,10 +23,10 @@ from crewai.mcp.config import (
MCPServerSSE,
MCPServerStdio,
)
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
from crewai.utilities.string_utils import sanitize_tool_name
if TYPE_CHECKING:
@@ -227,7 +228,9 @@ class MCPToolResolver:
server_params = {"url": server_url}
server_name = self._extract_server_name(server_url)
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
sanitized_specific_tool = (
sanitize_tool_name(specific_tool) if specific_tool else None
)
try:
tool_schemas = self._get_mcp_tool_schemas(server_params)
@@ -353,9 +356,10 @@ class MCPToolResolver:
asyncio.get_running_loop()
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run, _setup_client_and_list_tools()
ctx.run, asyncio.run, _setup_client_and_list_tools()
)
tools_list = future.result()
except RuntimeError:

View File

@@ -11,6 +11,7 @@ Orchestrates the encoding side of memory in a single Flow with 5 steps:
from __future__ import annotations
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import math
from typing import Any
@@ -164,7 +165,11 @@ class EncodingFlow(Flow[EncodingState]):
def parallel_find_similar(self) -> None:
"""Search storage for similar records, concurrently for all active items."""
items = list(self.state.items)
active = [(i, item) for i, item in enumerate(items) if not item.dropped and item.embedding]
active = [
(i, item)
for i, item in enumerate(items)
if not item.dropped and item.embedding
]
if not active:
return
@@ -185,8 +190,12 @@ class EncodingFlow(Flow[EncodingState]):
item.similar_records = [r for r, _ in raw]
item.top_similarity = float(raw[0][1]) if raw else 0.0
else:
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
futures = [(i, item, pool.submit(_search_one, item)) for i, item in active]
futures = [
(i, item, pool.submit(ctx.run, _search_one, item))
for i, item in active
]
for _, item, future in futures:
raw = future.result()
item.similar_records = [r for r, _ in raw]
@@ -229,6 +238,7 @@ class EncodingFlow(Flow[EncodingState]):
save_futures: dict[int, Future[MemoryAnalysis]] = {}
consol_futures: dict[int, Future[ConsolidationPlan]] = {}
ctx = contextvars.copy_context()
pool = ThreadPoolExecutor(max_workers=10)
try:
for i, item in enumerate(items):
@@ -250,24 +260,38 @@ class EncodingFlow(Flow[EncodingState]):
# Group B: consolidation only
self._apply_defaults(item)
consol_futures[i] = pool.submit(
ctx.run,
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
elif not fields_provided and not has_similar:
# Group C: field resolution only
save_futures[i] = pool.submit(
ctx.run,
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
else:
# Group D: both in parallel
save_futures[i] = pool.submit(
ctx.run,
analyze_for_save,
item.content, existing_scopes, existing_categories, self._llm,
item.content,
existing_scopes,
existing_categories,
self._llm,
)
consol_futures[i] = pool.submit(
ctx.run,
analyze_for_consolidation,
item.content, list(item.similar_records), self._llm,
item.content,
list(item.similar_records),
self._llm,
)
# Collect field-resolution results
@@ -339,7 +363,9 @@ class EncodingFlow(Flow[EncodingState]):
# similar_records overlap). Collect one action per record_id, first wins.
# Also build a map from record_id to the original MemoryRecord for updates.
dedup_deletes: set[str] = set() # record_ids to delete
dedup_updates: dict[str, tuple[int, str]] = {} # record_id -> (item_idx, new_content)
dedup_updates: dict[
str, tuple[int, str]
] = {} # record_id -> (item_idx, new_content)
all_similar: dict[str, MemoryRecord] = {} # record_id -> MemoryRecord
for i, item in enumerate(items):
@@ -350,13 +376,24 @@ class EncodingFlow(Flow[EncodingState]):
all_similar[r.id] = r
for action in item.plan.actions:
rid = action.record_id
if action.action == "delete" and rid not in dedup_deletes and rid not in dedup_updates:
if (
action.action == "delete"
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_deletes.add(rid)
elif action.action == "update" and action.new_content and rid not in dedup_deletes and rid not in dedup_updates:
elif (
action.action == "update"
and action.new_content
and rid not in dedup_deletes
and rid not in dedup_updates
):
dedup_updates[rid] = (i, action.new_content)
# --- Batch re-embed all update contents in ONE call ---
update_list = list(dedup_updates.items()) # [(record_id, (item_idx, new_content)), ...]
update_list = list(
dedup_updates.items()
) # [(record_id, (item_idx, new_content)), ...]
update_embeddings: list[list[float]] = []
if update_list:
update_contents = [content for _, (_, content) in update_list]
@@ -377,16 +414,21 @@ class EncodingFlow(Flow[EncodingState]):
if item.dropped or item.plan is None:
continue
if item.plan.insert_new:
to_insert.append((i, MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
)))
to_insert.append(
(
i,
MemoryRecord(
content=item.content,
scope=item.resolved_scope,
categories=item.resolved_categories,
metadata=item.resolved_metadata,
importance=item.resolved_importance,
embedding=item.embedding if item.embedding else None,
source=item.resolved_source,
private=item.resolved_private,
),
)
)
# All storage mutations under one lock so no other pipeline can
# interleave and cause version conflicts. The lock is reentrant

View File

@@ -11,6 +11,7 @@ Implements adaptive-depth retrieval with:
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from datetime import datetime
from typing import Any
from uuid import uuid4
@@ -103,13 +104,12 @@ class RecallFlow(Flow[RecallState]):
)
# Post-filter by time cutoff
if self.state.time_cutoff and raw:
raw = [
(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff
]
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
# Privacy filter
if not self.state.include_private and raw:
raw = [
(r, s) for r, s in raw
(r, s)
for r, s in raw
if not r.private or r.source == self.state.source
]
return scope, raw
@@ -130,15 +130,18 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
else:
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=min(len(tasks), 4)) as pool:
futures = {
pool.submit(_search_one, emb, sc): (emb, sc)
pool.submit(ctx.run, _search_one, emb, sc): (emb, sc)
for emb, sc in tasks
}
for future in as_completed(futures):
@@ -147,16 +150,16 @@ class RecallFlow(Flow[RecallState]):
top_composite, _ = compute_composite_score(
results[0][0], results[0][1], self._config
)
findings.append({
"scope": scope,
"results": results,
"top_score": top_composite,
})
findings.append(
{
"scope": scope,
"results": results,
"top_score": top_composite,
}
)
self.state.chunk_findings = findings
self.state.confidence = max(
(f["top_score"] for f in findings), default=0.0
)
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
return findings
# ------------------------------------------------------------------
@@ -210,12 +213,16 @@ class RecallFlow(Flow[RecallState]):
# Parse time_filter into a datetime cutoff
if analysis.time_filter:
try:
self.state.time_cutoff = datetime.fromisoformat(analysis.time_filter)
self.state.time_cutoff = datetime.fromisoformat(
analysis.time_filter
)
except ValueError:
pass
# Batch-embed all sub-queries in ONE call
queries = analysis.recall_queries if analysis.recall_queries else [self.state.query]
queries = (
analysis.recall_queries if analysis.recall_queries else [self.state.query]
)
queries = queries[:3]
embeddings = embed_texts(self._embedder, queries)
pairs: list[tuple[str, list[float]]] = [
@@ -296,17 +303,21 @@ class RecallFlow(Flow[RecallState]):
response = self._llm.call([{"role": "user", "content": prompt}])
if isinstance(response, str) and "missing" in response.lower():
self.state.evidence_gaps.append(response[:200])
enhanced.append({
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": response,
"results": finding["results"],
}
)
except Exception:
enhanced.append({
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
})
enhanced.append(
{
"scope": finding["scope"],
"extraction": "",
"results": finding["results"],
}
)
self.state.chunk_findings = enhanced
return enhanced

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from contextlib import AbstractContextManager
import contextvars
from datetime import datetime
import json
import logging
@@ -250,8 +251,10 @@ class LanceDBStorage:
def _compact_async(self) -> None:
"""Fire-and-forget: compact the table in a daemon background thread."""
ctx = contextvars.copy_context()
threading.Thread(
target=self._compact_safe,
target=ctx.run,
args=(self._compact_safe,),
daemon=True,
name="lancedb-compact",
).start()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
from datetime import datetime
import threading
import time
@@ -229,8 +230,12 @@ class Memory(BaseModel):
If the pool has been shut down (e.g. after ``close()``), the save
runs synchronously as a fallback so late saves still succeed.
"""
ctx = contextvars.copy_context()
try:
future: Future[Any] = self._save_pool.submit(fn, *args, **kwargs)
future: Future[Any] = self._save_pool.submit(
ctx.run,
lambda: fn(*args, **kwargs),
)
except RuntimeError:
# Pool shut down -- run synchronously as fallback
future = Future()

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import contextvars
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
@@ -169,8 +170,9 @@ def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
if loop and loop.is_running():
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return pool.submit(ctx.run, asyncio.run, result).result()
return asyncio.run(result)
return result

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import contextvars
from functools import partial
import inspect
from pathlib import Path
@@ -146,8 +147,9 @@ def _resolve_result(result: Any) -> Any:
if loop and loop.is_running():
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return pool.submit(ctx.run, asyncio.run, result).result()
return asyncio.run(result)
return result

View File

@@ -7,6 +7,7 @@ concurrently by the executor.
import asyncio
from collections.abc import Callable
import contextvars
from typing import Any
from crewai.tools import BaseTool
@@ -84,9 +85,10 @@ class MCPNativeTool(BaseTool):
import concurrent.futures
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor() as executor:
coro = self._run_async(**kwargs)
future = executor.submit(asyncio.run, coro)
future = executor.submit(ctx.run, asyncio.run, coro)
return future.result()
except RuntimeError:
return asyncio.run(self._run_async(**kwargs))

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
import concurrent.futures
import contextvars
import inspect
import json
import re
@@ -907,8 +908,9 @@ def summarize_messages(
chunks=chunks, llm=llm, callbacks=callbacks, i18n=i18n
)
if is_inside_event_loop():
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
summarized_contents = pool.submit(asyncio.run, coro).result()
summarized_contents = pool.submit(ctx.run, asyncio.run, coro).result()
else:
summarized_contents = asyncio.run(coro)

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
import concurrent.futures
import contextvars
import logging
from typing import TYPE_CHECKING, TypeVar
from uuid import UUID
@@ -46,8 +47,9 @@ def _run_sync(coro: Coroutine[None, None, T]) -> T:
"""
try:
asyncio.get_running_loop()
ctx = contextvars.copy_context()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(asyncio.run, coro)
future = executor.submit(ctx.run, asyncio.run, coro)
return future.result()
except RuntimeError:
return asyncio.run(coro)