fix: make _resolving_refs thread-safe via threading.local()

Address Bugbot review: replace module-level set with threading.local()
so concurrent schema conversions in ThreadPoolExecutor don't interfere.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2026-04-15 19:12:39 +00:00
parent ae09793712
commit 1305bfc7ea

View File

@@ -19,6 +19,7 @@ from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
import datetime import datetime
import logging import logging
import threading
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union, cast from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union, cast
import uuid import uuid
@@ -666,10 +667,21 @@ def build_rich_field_description(prop_schema: dict[str, Any]) -> str:
return ". ".join(parts) if parts else "" return ". ".join(parts) if parts else ""
# Thread-local set tracking which schemas are currently being converted to # Thread-local storage tracking which ``$ref`` paths are currently being
# Pydantic models. Used by ``_json_schema_to_pydantic_type`` to detect # resolved. Used by ``_json_schema_to_pydantic_type`` to detect circular
# circular ``$ref`` chains and break the recursion with a ``dict`` fallback. # ``$ref`` chains and break the recursion with a ``dict`` fallback.
_resolving_refs: set[str] = set() # Each thread gets its own independent set so concurrent schema conversions
# (e.g. via ThreadPoolExecutor in MCP tool resolution) don't interfere.
_resolving_refs_local = threading.local()
def _get_resolving_refs() -> set[str]:
"""Return the per-thread resolving-refs set, creating it on first access."""
try:
return _resolving_refs_local.refs # type: ignore[no-any-return]
except AttributeError:
_resolving_refs_local.refs = set() # type: ignore[attr-defined]
return _resolving_refs_local.refs # type: ignore[no-any-return]
def _safe_replace_refs(json_schema: dict[str, Any]) -> dict[str, Any]: def _safe_replace_refs(json_schema: dict[str, Any]) -> dict[str, Any]:
@@ -1021,9 +1033,10 @@ def _json_schema_to_pydantic_type(
if ref: if ref:
# Detect circular $ref chains - if we are already resolving this # Detect circular $ref chains - if we are already resolving this
# ref higher up the call stack, break the cycle by returning dict. # ref higher up the call stack, break the cycle by returning dict.
if ref in _resolving_refs: resolving = _get_resolving_refs()
if ref in resolving:
return dict return dict
_resolving_refs.add(ref) resolving.add(ref)
try: try:
ref_schema = _resolve_ref(ref, root_schema) ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type( return _json_schema_to_pydantic_type(
@@ -1033,7 +1046,7 @@ def _json_schema_to_pydantic_type(
enrich_descriptions=enrich_descriptions, enrich_descriptions=enrich_descriptions,
) )
finally: finally:
_resolving_refs.discard(ref) resolving.discard(ref)
enum_values = json_schema.get("enum") enum_values = json_schema.get("enum")
if enum_values: if enum_values: