# Source code for inline_snapshot_django

from __future__ import annotations

from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from typing import Any

import sql_impressao
from django.core.signals import request_started
from django.db import DEFAULT_DB_ALIAS, connections, reset_queries
from django.db.backends.utils import logger as sql_logger


@contextmanager
def snapshot_queries(
    *,
    using: str | Iterable[str] = "__all__",
) -> Generator[list[str | tuple[str, str]]]:
    """
    Capture the SQL queries executed inside the ``with`` block as fingerprints.

    Yields a list that stays empty while the block runs and is filled in when
    the block exits normally. Each entry is the ``sql_impressao`` fingerprint
    of one captured query, in execution order. Queries on the
    ``DEFAULT_DB_ALIAS`` connection are recorded as plain strings; queries on
    any other connection are recorded as ``(alias, fingerprint)`` tuples.

    :param using: Which database aliases to capture. ``"__all__"`` (the
        default) captures every configured connection, any other string is a
        single alias, and otherwise it is an iterable of aliases. Repeated
        aliases are treated as one.

    If the block raises, connection state is restored, the exception
    propagates, and the yielded list is left empty.
    """
    if isinstance(using, str):
        aliases = list(connections) if using == "__all__" else [using]
    else:
        # Drop duplicates (keeping order) so each connection's state is saved
        # and restored exactly once; otherwise the second save would record
        # the already-forced ``True`` and restore that instead.
        aliases = list(dict.fromkeys(using))

    queries: list[tuple[str, str]] = []
    record: list[str | tuple[str, str]] = []

    # State management adapted from Django’s CaptureQueriesContext. Setup
    # runs inside the ``try`` so that if ensure_connection() fails for one
    # alias, the connections already modified are still restored.
    saved_force_debug_cursors: list[tuple[str, bool]] = []
    reset_queries_disconnected = False
    try:
        for alias in aliases:
            connection = connections[alias]
            saved_force_debug_cursors.append((alias, connection.force_debug_cursor))
            # Force debug cursors so Django logs each query even when
            # DEBUG is off; those log calls are what get intercepted.
            connection.force_debug_cursor = True
            # Open the connection up front so that connection-setup queries
            # are not captured as part of the block.
            connection.ensure_connection()
        reset_queries_disconnected = request_started.disconnect(reset_queries)
        with _capture_debug_logged_queries(aliases, queries):
            yield record
    finally:
        if reset_queries_disconnected:
            request_started.connect(reset_queries)
        # Undo in reverse order of setup.
        for alias, force_debug_cursor in reversed(saved_force_debug_cursors):
            connections[alias].force_debug_cursor = force_debug_cursor

    # Group the raw SQL per alias so each connection's queries are
    # fingerprinted in one batch using that connection's dialect.
    sql_by_alias: defaultdict[str, list[str]] = defaultdict(list)
    for alias, sql in queries:
        sql_by_alias[alias].append(sql)

    fingerprints_by_alias: dict[str, deque[str]] = {}
    for alias in aliases:
        if alias not in sql_by_alias:
            continue
        # Use sql_impressao to format the SQL queries
        fingerprints_by_alias[alias] = deque(
            sql_impressao.fingerprint_many(
                sql_by_alias[alias],
                dialect=vendor_to_dialect.get(connections[alias].vendor, "generic"),
            )
        )

    # Re-interleave the fingerprints in the original execution order: each
    # alias's deque is consumed in the same order its queries were logged.
    for alias, _ in queries:
        fingerprint = fingerprints_by_alias[alias].popleft()
        if alias == DEFAULT_DB_ALIAS:
            record.append(fingerprint)
        else:
            record.append((alias, fingerprint))
@contextmanager
def _capture_debug_logged_queries(
    aliases: list[str], queries: list[tuple[str, str]]
) -> Generator[None]:
    """
    Wrap the debug() method of Django’s logger to intercept calls and capture
    the logged SQL queries.

    This is done instead of using a custom logging filter to avoid modifying
    the global logger configuration and to avoid adding logs to test output.

    Every ``debug()`` call whose ``extra`` is a dict containing both
    ``"alias"`` and ``"sql"``, with an alias in *aliases*, appends
    ``(alias, sql)`` to *queries*. Every call, captured or not, is then
    forwarded unchanged to the previous ``debug`` method.
    """
    alias_set = set(aliases)
    # Remember whether the logger instance already had its own ``debug``
    # attribute (e.g. set by an enclosing capture). This lets exit restore
    # the exact previous state instead of leaving a bound method pinned on
    # the instance, which would shadow the class method from then on.
    had_instance_debug = "debug" in vars(sql_logger)
    original_debug = sql_logger.debug

    def debug_wrapper(*args: Any, extra: Any = None, **kwargs: Any) -> Any:
        if isinstance(extra, dict) and "alias" in extra and "sql" in extra:
            alias = extra["alias"]
            if alias in alias_set:
                queries.append((alias, extra["sql"]))
        return original_debug(*args, extra=extra, **kwargs)

    sql_logger.debug = debug_wrapper  # type: ignore[method-assign]
    try:
        yield
    finally:
        if had_instance_debug:
            sql_logger.debug = original_debug  # type: ignore[method-assign]
        else:
            # Remove the instance attribute so lookups fall back to the
            # logger class's method again.
            del sql_logger.debug


# Map Django database backend 'vendor' strings to sqlparser dialect names
# passed to sql_impressao; vendors not listed fall back to "generic" where
# this mapping is used. (They’re all the same right now…)
vendor_to_dialect = {
    "postgresql": "postgresql",
    "mysql": "mysql",
    "oracle": "oracle",
    "sqlite": "sqlite",
}